Add role-based filtering and improve code

This commit is contained in:
Urtzi Alfaro
2025-10-15 16:12:49 +02:00
parent 96ad5c6692
commit 8f9e9a7edc
158 changed files with 11033 additions and 1544 deletions

View File

@@ -8,11 +8,19 @@ from typing import Optional, Dict, Any
import structlog
from datetime import datetime, timezone
import uuid
import shared.redis_utils
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role, admin_role_required
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
from shared.subscription.plans import (
get_training_job_quota,
get_dataset_size_limit
)
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
@@ -20,6 +28,11 @@ from app.schemas.training import (
SingleProductTrainingRequest,
TrainingJobResponse
)
from app.utils.time_estimation import (
calculate_initial_estimate,
calculate_estimated_completion_time,
get_historical_average_estimate
)
from app.services.training_events import (
publish_training_started,
publish_training_completed,
@@ -32,6 +45,30 @@ route_builder = RouteBuilder('training')
router = APIRouter(tags=["training-operations"])
# Initialize audit logger
audit_logger = create_audit_logger("training-service")
# Redis client for rate limiting
_redis_client = None
async def get_training_redis_client():
"""Get or create Redis client for rate limiting"""
global _redis_client
if _redis_client is None:
# Initialize Redis if not already done
try:
from app.core.config import settings
_redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
except Exception:
# Fallback to getting the client directly (if already initialized elsewhere)
_redis_client = await shared.redis_utils.get_redis_client()
return _redis_client
async def get_rate_limiter():
"""Dependency for rate limiter"""
redis_client = await get_training_redis_client()
return create_rate_limiter(redis_client)
def get_enhanced_training_service():
"""Dependency injection for EnhancedTrainingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
@@ -40,31 +77,82 @@ def get_enhanced_training_service():
@router.post(
route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
request: TrainingJobRequest,
tenant_id: str = Path(..., description="Tenant ID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
current_user: Dict[str, Any] = Depends(get_current_user_dep),
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service),
rate_limiter = Depends(get_rate_limiter)
):
"""
Start a new training job for all tenant products using repository pattern.
Start a new training job for all tenant products (Admin+ only, quota enforced).
**RBAC:** Admin or Owner role required
**Quotas:**
- Starter: 1 training job/day, max 1,000 rows
- Professional: 5 training jobs/day, max 10,000 rows
- Enterprise: Unlimited jobs, unlimited rows
Enhanced immediate response pattern:
1. Validate request with enhanced validation
2. Create job record using repository pattern
3. Return 200 with enhanced job details
4. Execute enhanced training in background with repository tracking
1. Validate subscription tier and quotas
2. Validate request with enhanced validation
3. Create job record using repository pattern
4. Return 200 with enhanced job details
5. Execute enhanced training in background with repository tracking
Enhanced features:
- Repository pattern for data access
- Quota enforcement by subscription tier
- Audit logging for all operations
- Enhanced error handling and logging
- Metrics tracking and monitoring
- Transactional operations
"""
metrics = get_metrics_collector(request_obj)
# Get subscription tier and enforce quotas
tier = current_user.get('subscription_tier', 'starter')
# Estimate the dataset size; ideally this comes from the request, so fall back to a
# conservative default when the field is not provided
estimated_dataset_size = getattr(request, 'estimated_rows', 500)
# Initialize variables for later use
quota_result = None
quota_limit = None
try:
# Validate dataset size limits
await rate_limiter.validate_dataset_size(
tenant_id, estimated_dataset_size, tier
)
# Check daily training job quota
quota_limit = get_training_job_quota(tier)
quota_result = await rate_limiter.check_and_increment_quota(
tenant_id,
"training_jobs",
quota_limit,
period=86400 # 24 hours
)
logger.info("Training job quota check passed",
tenant_id=tenant_id,
tier=tier,
current_usage=quota_result.get('current', 0) if quota_result else 0,
limit=quota_limit)
except HTTPException:
# Quota or validation error - re-raise
raise
except Exception as quota_error:
logger.error("Quota validation failed", error=str(quota_error))
# Continue with job creation but log the error
try:
# Generate enhanced job ID
job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
@@ -85,6 +173,25 @@ async def start_training_job(
total_products=0 # Will be updated when actual training starts
)
# Calculate intelligent time estimate
# The exact product count is not known yet, so use the historical average or a default estimate
try:
# Try to get historical average for this tenant
from app.core.database import get_db
db = next(get_db())
historical_avg = get_historical_average_estimate(db, tenant_id)
# If no historical data, estimate based on typical product count (10-20 products)
estimated_products = 15 # Conservative estimate
estimated_duration_minutes = calculate_initial_estimate(
total_products=estimated_products,
avg_training_time_per_product=historical_avg if historical_avg else 60.0
)
except Exception as est_error:
logger.warning("Could not calculate intelligent estimate, using default",
error=str(est_error))
estimated_duration_minutes = 15 # Default fallback
# Add enhanced background task
background_tasks.add_task(
execute_training_job_background,
@@ -92,7 +199,8 @@ async def start_training_job(
job_id=job_id,
bakery_location=(40.4168, -3.7038),
requested_start=request.start_date,
requested_end=request.end_date
requested_end=request.end_date,
estimated_duration_minutes=estimated_duration_minutes
)
# Return enhanced immediate success response
@@ -102,7 +210,7 @@ async def start_training_job(
"status": "pending",
"message": "Enhanced training job started successfully using repository pattern",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": 18,
"estimated_duration_minutes": estimated_duration_minutes,
"training_results": {
"total_products": 0,
"successful_trainings": 0,
@@ -126,6 +234,32 @@ async def start_training_job(
job_id=job_id,
features=["repository-pattern", "dependency-injection", "enhanced-tracking"])
# Log audit event for training job creation
try:
from app.core.database import get_db
db = next(get_db())
await audit_logger.log_event(
db_session=db,
tenant_id=tenant_id,
user_id=current_user["user_id"],
action=AuditAction.CREATE.value,
resource_type="training_job",
resource_id=job_id,
severity=AuditSeverity.MEDIUM.value,
description=f"Started training job (tier: {tier})",
metadata={
"job_id": job_id,
"tier": tier,
"estimated_dataset_size": estimated_dataset_size,
"quota_usage": quota_result.get('current', 0) if quota_result else 0,
"quota_limit": quota_limit if quota_limit else "unlimited"
},
endpoint="/jobs",
method="POST"
)
except Exception as audit_error:
logger.warning("Failed to log audit event", error=str(audit_error))
return TrainingJobResponse(**response_data)
except HTTPException:
@@ -157,7 +291,8 @@ async def execute_training_job_background(
job_id: str,
bakery_location: tuple,
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None
requested_end: Optional[datetime] = None,
estimated_duration_minutes: int = 15
):
"""
Enhanced background task that executes the training job using repository pattern.
@@ -202,7 +337,7 @@ async def execute_training_job_background(
},
"requested_start": requested_start.isoformat() if requested_start else None,
"requested_end": requested_end.isoformat() if requested_end else None,
"estimated_duration_minutes": 18,
"estimated_duration_minutes": estimated_duration_minutes,
"background_execution": True,
"enhanced_features": True,
"repository_pattern": True,
@@ -278,16 +413,20 @@ async def execute_training_job_background(
@router.post(
route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
request: SingleProductTrainingRequest,
tenant_id: str = Path(..., description="Tenant ID"),
inventory_product_id: str = Path(..., description="Inventory product UUID"),
request_obj: Request = None,
current_user: Dict[str, Any] = Depends(get_current_user_dep),
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start enhanced training for a single product using repository pattern.
Start enhanced training for a single product (Admin+ only).
**RBAC:** Admin or Owner role required
Enhanced features:
- Repository pattern for data access
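Both training endpoints are gated with @require_user_role(['admin', 'owner']). The decorator itself lives in shared.auth.access_control and is not part of this diff; the sketch below only illustrates the general shape such a guard tends to have (inspect the resolved current_user dependency and reject with 403 otherwise). The 'role' key and the helper name are assumptions.

# Hypothetical role-guard sketch; the real decorator is provided by shared.auth.access_control.
import functools
from typing import Any, Dict, List
from fastapi import HTTPException

def require_user_role_sketch(allowed_roles: List[str]):
    def decorator(endpoint):
        @functools.wraps(endpoint)
        async def wrapper(*args, **kwargs):
            current_user: Dict[str, Any] = kwargs.get("current_user") or {}
            if current_user.get("role") not in allowed_roles:   # 'role' key assumed
                raise HTTPException(status_code=403, detail="Insufficient role")
            return await endpoint(*args, **kwargs)
        return wrapper
    return decorator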

View File

@@ -4,6 +4,13 @@ Training Service Models Package
Import all models to ensure they are registered with SQLAlchemy Base.
"""
# Import AuditLog model for this service
from shared.security import create_audit_log_model
from shared.database.base import Base
# Create audit log model for this service
AuditLog = create_audit_log_model(Base)
# Import all models to register them with the Base metadata
from .training import (
TrainedModel,
@@ -20,4 +27,5 @@ __all__ = [
"ModelPerformanceMetric",
"TrainingJobQueue",
"ModelArtifact",
"AuditLog",
]

View File

@@ -193,4 +193,59 @@ class TrainedModel(Base):
"training_start_date": self.training_start_date.isoformat() if self.training_start_date else None,
"training_end_date": self.training_end_date.isoformat() if self.training_end_date else None,
"data_quality_score": self.data_quality_score
}
class TrainingPerformanceMetrics(Base):
"""
Table to track historical training performance for time estimation.
Stores aggregated metrics from completed training jobs.
"""
__tablename__ = "training_performance_metrics"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
job_id = Column(String(255), nullable=False, index=True)
# Training job statistics
total_products = Column(Integer, nullable=False)
successful_products = Column(Integer, nullable=False)
failed_products = Column(Integer, nullable=False)
# Time metrics
total_duration_seconds = Column(Float, nullable=False)
avg_time_per_product = Column(Float, nullable=False) # Key metric for estimation
data_analysis_time_seconds = Column(Float, nullable=True)
training_time_seconds = Column(Float, nullable=True)
finalization_time_seconds = Column(Float, nullable=True)
# Job metadata
completed_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
def __repr__(self):
return (
f"<TrainingPerformanceMetrics("
f"tenant_id={self.tenant_id}, "
f"job_id={self.job_id}, "
f"total_products={self.total_products}, "
f"avg_time_per_product={self.avg_time_per_product:.2f}s"
f")>"
)
def to_dict(self):
return {
"id": str(self.id),
"tenant_id": str(self.tenant_id),
"job_id": self.job_id,
"total_products": self.total_products,
"successful_products": self.successful_products,
"failed_products": self.failed_products,
"total_duration_seconds": self.total_duration_seconds,
"avg_time_per_product": self.avg_time_per_product,
"data_analysis_time_seconds": self.data_analysis_time_seconds,
"training_time_seconds": self.training_time_seconds,
"finalization_time_seconds": self.finalization_time_seconds,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"created_at": self.created_at.isoformat() if self.created_at else None
}
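This table only becomes useful once completed jobs write rows into it, since get_historical_average_estimate reads from it. A minimal sketch of such a write follows; the helper name record_training_metrics is hypothetical and it assumes a synchronous SQLAlchemy session, while the training service in this commit works with async sessions.

# Hypothetical helper: persist aggregated metrics after a job finishes so that
# get_historical_average_estimate() has data to draw on.
from datetime import datetime, timezone
from app.models.training import TrainingPerformanceMetrics

def record_training_metrics(db_session, tenant_id, job_id, total_products,
                            successful_products, failed_products,
                            total_duration_seconds):
    completed = max(successful_products, 1)          # avoid division by zero
    metric = TrainingPerformanceMetrics(
        tenant_id=tenant_id,
        job_id=job_id,
        total_products=total_products,
        successful_products=successful_products,
        failed_products=failed_products,
        total_duration_seconds=total_duration_seconds,
        avg_time_per_product=total_duration_seconds / completed,  # the key estimation input
        completed_at=datetime.now(timezone.utc),
    )
    db_session.add(metric)
    db_session.commit()
    return metric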

View File

@@ -112,6 +112,8 @@ class TrainingJobStatus(BaseModel):
products_completed: int = Field(0, description="Number of products completed")
products_failed: int = Field(0, description="Number of products that failed")
error_message: Optional[str] = Field(None, description="Error message if failed")
estimated_time_remaining_seconds: Optional[int] = Field(None, description="Estimated time remaining in seconds")
message: Optional[str] = Field(None, description="Optional status message")
@validator('job_id', pre=True)
def convert_uuid_to_string(cls, v):

View File

@@ -38,10 +38,19 @@ async def cleanup_messaging():
async def publish_training_started(
job_id: str,
tenant_id: str,
total_products: int
total_products: int,
estimated_duration_minutes: Optional[int] = None,
estimated_completion_time: Optional[str] = None
) -> bool:
"""
Event 1: Training Started (0% progress)
Args:
job_id: Training job identifier
tenant_id: Tenant identifier
total_products: Number of products to train
estimated_duration_minutes: Estimated time to completion in minutes
estimated_completion_time: ISO timestamp of estimated completion
"""
event_data = {
"service_name": "training-service",
@@ -53,7 +62,10 @@ async def publish_training_started(
"progress": 0,
"current_step": "Training Started",
"step_details": f"Starting training for {total_products} products",
"total_products": total_products
"total_products": total_products,
"estimated_duration_minutes": estimated_duration_minutes,
"estimated_completion_time": estimated_completion_time,
"estimated_time_remaining_seconds": estimated_duration_minutes * 60 if estimated_duration_minutes else None
}
}
@@ -67,7 +79,8 @@ async def publish_training_started(
logger.info("Published training started event",
job_id=job_id,
tenant_id=tenant_id,
total_products=total_products)
total_products=total_products,
estimated_duration_minutes=estimated_duration_minutes)
else:
logger.error("Failed to publish training started event", job_id=job_id)
@@ -77,10 +90,17 @@ async def publish_training_started(
async def publish_data_analysis(
job_id: str,
tenant_id: str,
analysis_details: Optional[str] = None
analysis_details: Optional[str] = None,
estimated_time_remaining_seconds: Optional[int] = None
) -> bool:
"""
Event 2: Data Analysis (20% progress)
Args:
job_id: Training job identifier
tenant_id: Tenant identifier
analysis_details: Details about the analysis
estimated_time_remaining_seconds: Estimated time remaining in seconds
"""
event_data = {
"service_name": "training-service",
@@ -91,7 +111,8 @@ async def publish_data_analysis(
"tenant_id": tenant_id,
"progress": 20,
"current_step": "Data Analysis",
"step_details": analysis_details or "Analyzing sales, weather, and traffic data"
"step_details": analysis_details or "Analyzing sales, weather, and traffic data",
"estimated_time_remaining_seconds": estimated_time_remaining_seconds
}
}
@@ -116,7 +137,8 @@ async def publish_product_training_completed(
tenant_id: str,
product_name: str,
products_completed: int,
total_products: int
total_products: int,
estimated_time_remaining_seconds: Optional[int] = None
) -> bool:
"""
Event 3: Product Training Completed (contributes to 20-80% progress)
@@ -124,6 +146,14 @@ async def publish_product_training_completed(
This event is published each time a product training completes.
The frontend/consumer will calculate the progress as:
progress = 20 + (products_completed / total_products) * 60
Args:
job_id: Training job identifier
tenant_id: Tenant identifier
product_name: Name of the product that was trained
products_completed: Number of products completed so far
total_products: Total number of products
estimated_time_remaining_seconds: Estimated time remaining in seconds
"""
event_data = {
"service_name": "training-service",
@@ -136,7 +166,8 @@ async def publish_product_training_completed(
"products_completed": products_completed,
"total_products": total_products,
"current_step": "Model Training",
"step_details": f"Completed training for {product_name} ({products_completed}/{total_products})"
"step_details": f"Completed training for {product_name} ({products_completed}/{total_products})",
"estimated_time_remaining_seconds": estimated_time_remaining_seconds
}
}
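On the consumer side, the documented formula plus the new estimated_time_remaining_seconds field are enough to render a progress bar with an ETA. A small illustrative helper, where the payload keys match the event_data["data"] dicts above and everything else is assumed:

# Illustrative consumer-side helper; payload keys follow the events published above.
from typing import Any, Dict, Optional

def render_progress(data: Dict[str, Any]) -> str:
    completed = data.get("products_completed", 0)
    total = data.get("total_products", 0) or 1                   # avoid division by zero
    progress = 20 + (completed / total) * 60                     # training occupies the 20-80% band
    remaining: Optional[int] = data.get("estimated_time_remaining_seconds")
    eta = f", ~{remaining // 60} min remaining" if remaining else ""
    return f"{progress:.0f}% ({completed}/{total} products{eta})"

# Example: 3 of 10 products done with 420 s left -> "38% (3/10 products, ~7 min remaining)"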

View File

@@ -452,23 +452,50 @@ class EnhancedTrainingService:
try:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
log = await self.training_log_repo.get_log_by_job_id(job_id)
if not log:
return {"error": "Job not found"}
# Calculate estimated time remaining based on progress and elapsed time
estimated_time_remaining_seconds = None
if log.status == "running" and log.progress > 0 and log.start_time:
from datetime import datetime, timezone
elapsed_time = (datetime.now(timezone.utc) - log.start_time).total_seconds()
if elapsed_time > 0:
# Calculate estimated total time based on progress
estimated_total_time = (elapsed_time / log.progress) * 100
estimated_time_remaining_seconds = int(estimated_total_time - elapsed_time)
# Cap at reasonable maximum (e.g., 30 minutes)
estimated_time_remaining_seconds = min(estimated_time_remaining_seconds, 1800)
# Extract products info from results if available
products_total = 0
products_completed = 0
products_failed = 0
if log.results:
products_total = log.results.get("total_products", 0)
products_completed = log.results.get("successful_trainings", 0)
products_failed = log.results.get("failed_trainings", 0)
return {
"job_id": job_id,
"tenant_id": log.tenant_id,
"status": log.status,
"progress": log.progress,
"current_step": log.current_step,
"start_time": log.start_time.isoformat() if log.start_time else None,
"end_time": log.end_time.isoformat() if log.end_time else None,
"started_at": log.start_time.isoformat() if log.start_time else None,
"completed_at": log.end_time.isoformat() if log.end_time else None,
"error_message": log.error_message,
"results": log.results
"results": log.results,
"products_total": products_total,
"products_completed": products_completed,
"products_failed": products_failed,
"estimated_time_remaining_seconds": estimated_time_remaining_seconds,
"message": log.current_step
}
except Exception as e:
logger.error("Failed to get training status",
job_id=job_id,
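As a worked example of the linear extrapolation above: a job that has been running for 300 s at 40% progress implies an estimated total of (300 / 40) * 100 = 750 s, so roughly 450 s remain, comfortably under the 1800 s (30-minute) cap.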

View File

@@ -0,0 +1,332 @@
"""
Training Time Estimation Utilities
Provides intelligent time estimation for training jobs based on:
- Product count
- Historical performance data
- Current progress and throughput
"""
from typing import List, Optional
from datetime import datetime, timedelta, timezone
import structlog
logger = structlog.get_logger()
def calculate_initial_estimate(
total_products: int,
avg_training_time_per_product: float = 60.0, # seconds, default 1 min/product
data_analysis_overhead: float = 120.0, # seconds, data loading & analysis
finalization_overhead: float = 60.0, # seconds, saving models & cleanup
min_estimate_minutes: int = 5,
max_estimate_minutes: int = 60
) -> int:
"""
Calculate realistic initial time estimate for training job.
Formula:
total_time = data_analysis + (products * avg_time_per_product) + finalization
Args:
total_products: Number of products to train
avg_training_time_per_product: Average time per product in seconds
data_analysis_overhead: Time for data loading and analysis in seconds
finalization_overhead: Time for saving models and cleanup in seconds
min_estimate_minutes: Minimum estimate (prevents unrealistic low values)
max_estimate_minutes: Maximum estimate (prevents unrealistic high values)
Returns:
Estimated duration in minutes
Examples:
>>> calculate_initial_estimate(1)
5 # 120 + 60 + 60 = 240s = 4min, raised to the 5-minute minimum
>>> calculate_initial_estimate(5)
8 # 120 + 300 + 60 = 480s = 8min
>>> calculate_initial_estimate(10)
13 # 120 + 600 + 60 = 780s = 13min
>>> calculate_initial_estimate(20)
23 # 120 + 1200 + 60 = 1380s = 23min
>>> calculate_initial_estimate(100)
60 # Capped at max (would be 103 min)
"""
# Calculate total estimated time in seconds
estimated_seconds = (
data_analysis_overhead +
(total_products * avg_training_time_per_product) +
finalization_overhead
)
# Convert to minutes, rounding to the nearest minute
estimated_minutes = int((estimated_seconds / 60) + 0.5)
# Apply min/max bounds
estimated_minutes = max(min_estimate_minutes, min(max_estimate_minutes, estimated_minutes))
logger.info(
"Calculated initial time estimate",
total_products=total_products,
estimated_seconds=estimated_seconds,
estimated_minutes=estimated_minutes,
avg_time_per_product=avg_training_time_per_product
)
return estimated_minutes
def calculate_estimated_completion_time(
estimated_duration_minutes: int,
start_time: Optional[datetime] = None
) -> datetime:
"""
Calculate estimated completion timestamp.
Args:
estimated_duration_minutes: Estimated duration in minutes
start_time: Job start time (defaults to now)
Returns:
Estimated completion datetime (timezone-aware UTC)
"""
if start_time is None:
start_time = datetime.now(timezone.utc)
completion_time = start_time + timedelta(minutes=estimated_duration_minutes)
return completion_time
def calculate_remaining_time_smart(
progress: int,
elapsed_time: float,
products_completed: int,
total_products: int,
recent_product_times: Optional[List[float]] = None,
max_remaining_seconds: int = 1800 # 30 minutes
) -> Optional[int]:
"""
Calculate remaining time using smart algorithm that considers:
- Current progress percentage
- Actual throughput (products completed / elapsed time)
- Recent performance (weighted moving average)
Args:
progress: Current progress percentage (0-100)
elapsed_time: Time elapsed since job start (seconds)
products_completed: Number of products completed
total_products: Total number of products
recent_product_times: List of recent product training times (seconds)
max_remaining_seconds: Maximum remaining time (safety cap)
Returns:
Estimated remaining time in seconds, or None if can't calculate
"""
# Job completed or not started
if progress >= 100 or progress <= 0:
return None
# Early stage (0-20%): Use weighted estimate
if progress <= 20:
# In data analysis phase - estimate based on remaining products
remaining_products = total_products - products_completed
if recent_product_times and len(recent_product_times) > 0:
# Use recent performance if available
avg_time_per_product = sum(recent_product_times) / len(recent_product_times)
else:
# Fallback to default
avg_time_per_product = 60.0 # 1 minute per product
# Estimate: remaining products * avg time + overhead
estimated_remaining = (remaining_products * avg_time_per_product) + 60.0 # +1 min overhead
logger.debug(
"Early stage estimation",
progress=progress,
remaining_products=remaining_products,
avg_time_per_product=avg_time_per_product,
estimated_remaining=estimated_remaining
)
# Mid/late stage (21-99%): Use actual throughput
else:
if products_completed > 0:
# Calculate actual time per product from current run
actual_time_per_product = elapsed_time / products_completed
remaining_products = total_products - products_completed
estimated_remaining = remaining_products * actual_time_per_product
logger.debug(
"Mid/late stage estimation",
progress=progress,
products_completed=products_completed,
total_products=total_products,
actual_time_per_product=actual_time_per_product,
estimated_remaining=estimated_remaining
)
else:
# Fallback to linear extrapolation
estimated_total = (elapsed_time / progress) * 100
estimated_remaining = estimated_total - elapsed_time
logger.debug(
"Fallback linear estimation",
progress=progress,
elapsed_time=elapsed_time,
estimated_remaining=estimated_remaining
)
# Apply safety cap
estimated_remaining = min(estimated_remaining, max_remaining_seconds)
return int(estimated_remaining)
def calculate_average_product_time(
products_completed: int,
elapsed_time: float,
min_products_threshold: int = 3
) -> Optional[float]:
"""
Calculate average time per product from current job progress.
Args:
products_completed: Number of products completed
elapsed_time: Time elapsed since job start (seconds)
min_products_threshold: Minimum products needed for reliable calculation
Returns:
Average time per product in seconds, or None if insufficient data
"""
if products_completed < min_products_threshold:
return None
avg_time = elapsed_time / products_completed
logger.debug(
"Calculated average product time",
products_completed=products_completed,
elapsed_time=elapsed_time,
avg_time=avg_time
)
return avg_time
def format_time_remaining(seconds: int) -> str:
"""
Format remaining time in human-readable format.
Args:
seconds: Time in seconds
Returns:
Formatted string (e.g., "5 minutes", "1 hour 23 minutes")
Examples:
>>> format_time_remaining(45)
"45 seconds"
>>> format_time_remaining(180)
"3 minutes"
>>> format_time_remaining(5400)
"1 hour 30 minutes"
"""
if seconds < 60:
return f"{seconds} seconds"
minutes = seconds // 60
remaining_seconds = seconds % 60
if minutes < 60:
if remaining_seconds > 0:
return f"{minutes} minutes {remaining_seconds} seconds"
return f"{minutes} minutes"
hours = minutes // 60
remaining_minutes = minutes % 60
if remaining_minutes > 0:
return f"{hours} hour{'s' if hours > 1 else ''} {remaining_minutes} minutes"
return f"{hours} hour{'s' if hours > 1 else ''}"
def get_historical_average_estimate(
db_session,
tenant_id: str,
lookback_days: int = 30,
limit: int = 10
) -> Optional[float]:
"""
Get historical average training time per product for a tenant.
This function queries the TrainingPerformanceMetrics table to get
recent historical data and calculate an average.
Args:
db_session: Database session
tenant_id: Tenant UUID
lookback_days: How many days back to look
limit: Maximum number of historical records to consider
Returns:
Average time per product in seconds, or None if no historical data
"""
try:
from app.models.training import TrainingPerformanceMetrics
from datetime import timedelta
cutoff = datetime.now(timezone.utc) - timedelta(days=lookback_days)
# Query recent training performance metrics
metrics = db_session.query(TrainingPerformanceMetrics).filter(
TrainingPerformanceMetrics.tenant_id == tenant_id,
TrainingPerformanceMetrics.completed_at >= cutoff
).order_by(
TrainingPerformanceMetrics.completed_at.desc()
).limit(limit).all()
if not metrics:
logger.info(
"No historical training data found",
tenant_id=tenant_id,
lookback_days=lookback_days
)
return None
# Calculate weighted average (more recent = higher weight)
total_weight = 0
weighted_sum = 0
for i, metric in enumerate(metrics):
# Weight: newer records get higher weight
weight = limit - i
weighted_sum += metric.avg_time_per_product * weight
total_weight += weight
if total_weight == 0:
return None
weighted_avg = weighted_sum / total_weight
logger.info(
"Calculated historical average",
tenant_id=tenant_id,
records_used=len(metrics),
weighted_avg=weighted_avg
)
return weighted_avg
except Exception as e:
logger.error(
"Error getting historical average",
tenant_id=tenant_id,
error=str(e)
)
return None
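Taken together, these helpers cover the full lifecycle of an estimate. A usage sketch with the module's default overheads (the numbers follow directly from the functions above):

# Usage sketch of the estimation helpers defined in this module.
from app.utils.time_estimation import (
    calculate_initial_estimate,
    calculate_estimated_completion_time,
    calculate_remaining_time_smart,
    format_time_remaining,
)

# Before the job starts: 15 products at the 60 s default
# -> 120 + 15*60 + 60 = 1080 s, i.e. an 18-minute initial estimate.
initial_minutes = calculate_initial_estimate(total_products=15)
eta = calculate_estimated_completion_time(initial_minutes)

# Mid-run: 6 of 15 products finished after 400 s
# -> roughly 66.7 s per product, 9 products left, about 600 s remaining.
remaining = calculate_remaining_time_smart(
    progress=44,                     # 20 + (6/15) * 60, i.e. inside the training band
    elapsed_time=400.0,
    products_completed=6,
    total_products=15,
)
print(initial_minutes, eta.isoformat(), format_time_remaining(remaining or 0))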