Add role-based filtering and improve code
@@ -8,11 +8,19 @@ from typing import Optional, Dict, Any
import structlog
from datetime import datetime, timezone
import uuid
import shared.redis_utils

from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role, admin_role_required
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
from shared.subscription.plans import (
get_training_job_quota,
get_dataset_size_limit
)

from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
@@ -20,6 +28,11 @@ from app.schemas.training import (
SingleProductTrainingRequest,
TrainingJobResponse
)
from app.utils.time_estimation import (
calculate_initial_estimate,
calculate_estimated_completion_time,
get_historical_average_estimate
)
from app.services.training_events import (
publish_training_started,
publish_training_completed,
@@ -32,6 +45,30 @@ route_builder = RouteBuilder('training')

router = APIRouter(tags=["training-operations"])

# Initialize audit logger
audit_logger = create_audit_logger("training-service")

# Redis client for rate limiting
_redis_client = None

async def get_training_redis_client():
"""Get or create Redis client for rate limiting"""
global _redis_client
if _redis_client is None:
# Initialize Redis if not already done
try:
from app.core.config import settings
_redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
except Exception:
# Fallback to getting the client directly (if already initialized elsewhere)
_redis_client = await shared.redis_utils.get_redis_client()
return _redis_client

async def get_rate_limiter():
"""Dependency for rate limiter"""
redis_client = await get_training_redis_client()
return create_rate_limiter(redis_client)

def get_enhanced_training_service():
"""Dependency injection for EnhancedTrainingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
@@ -40,31 +77,82 @@ def get_enhanced_training_service():

@router.post(
route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
request: TrainingJobRequest,
tenant_id: str = Path(..., description="Tenant ID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
current_user: Dict[str, Any] = Depends(get_current_user_dep),
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service),
rate_limiter = Depends(get_rate_limiter)
):
"""
Start a new training job for all tenant products using repository pattern.
Start a new training job for all tenant products (Admin+ only, quota enforced).

**RBAC:** Admin or Owner role required
**Quotas:**
- Starter: 1 training job/day, max 1,000 rows
- Professional: 5 training jobs/day, max 10,000 rows
- Enterprise: Unlimited jobs, unlimited rows

Enhanced immediate response pattern:
1. Validate request with enhanced validation
2. Create job record using repository pattern
3. Return 200 with enhanced job details
4. Execute enhanced training in background with repository tracking
1. Validate subscription tier and quotas
2. Validate request with enhanced validation
3. Create job record using repository pattern
4. Return 200 with enhanced job details
5. Execute enhanced training in background with repository tracking

Enhanced features:
- Repository pattern for data access
- Quota enforcement by subscription tier
- Audit logging for all operations
- Enhanced error handling and logging
- Metrics tracking and monitoring
- Transactional operations
"""
metrics = get_metrics_collector(request_obj)

# Get subscription tier and enforce quotas
tier = current_user.get('subscription_tier', 'starter')

# Estimate dataset size (this should come from the request or be calculated)
# For now, we'll assume a reasonable estimate
estimated_dataset_size = request.estimated_rows if hasattr(request, 'estimated_rows') else 500

# Initialize variables for later use
quota_result = None
quota_limit = None

try:
# Validate dataset size limits
await rate_limiter.validate_dataset_size(
tenant_id, estimated_dataset_size, tier
)

# Check daily training job quota
quota_limit = get_training_job_quota(tier)
quota_result = await rate_limiter.check_and_increment_quota(
tenant_id,
"training_jobs",
quota_limit,
period=86400 # 24 hours
)

logger.info("Training job quota check passed",
tenant_id=tenant_id,
tier=tier,
current_usage=quota_result.get('current', 0) if quota_result else 0,
limit=quota_limit)

except HTTPException:
# Quota or validation error - re-raise
raise
except Exception as quota_error:
logger.error("Quota validation failed", error=str(quota_error))
# Continue with job creation but log the error

try:
# Generate enhanced job ID
job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
@@ -85,6 +173,25 @@ async def start_training_job(
total_products=0 # Will be updated when actual training starts
)

# Calculate intelligent time estimate
# We don't know exact product count yet, so use historical average or estimate
try:
# Try to get historical average for this tenant
from app.core.database import get_db
db = next(get_db())
historical_avg = get_historical_average_estimate(db, tenant_id)

# If no historical data, estimate based on typical product count (10-20 products)
estimated_products = 15 # Conservative estimate
estimated_duration_minutes = calculate_initial_estimate(
total_products=estimated_products,
avg_training_time_per_product=historical_avg if historical_avg else 60.0
)
except Exception as est_error:
logger.warning("Could not calculate intelligent estimate, using default",
error=str(est_error))
estimated_duration_minutes = 15 # Default fallback

# Add enhanced background task
background_tasks.add_task(
execute_training_job_background,
@@ -92,7 +199,8 @@ async def start_training_job(
job_id=job_id,
bakery_location=(40.4168, -3.7038),
requested_start=request.start_date,
requested_end=request.end_date
requested_end=request.end_date,
estimated_duration_minutes=estimated_duration_minutes
)

# Return enhanced immediate success response
@@ -102,7 +210,7 @@ async def start_training_job(
"status": "pending",
"message": "Enhanced training job started successfully using repository pattern",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": 18,
"estimated_duration_minutes": estimated_duration_minutes,
"training_results": {
"total_products": 0,
"successful_trainings": 0,
@@ -126,6 +234,32 @@ async def start_training_job(
job_id=job_id,
features=["repository-pattern", "dependency-injection", "enhanced-tracking"])

# Log audit event for training job creation
try:
from app.core.database import get_db
db = next(get_db())
await audit_logger.log_event(
db_session=db,
tenant_id=tenant_id,
user_id=current_user["user_id"],
action=AuditAction.CREATE.value,
resource_type="training_job",
resource_id=job_id,
severity=AuditSeverity.MEDIUM.value,
description=f"Started training job (tier: {tier})",
metadata={
"job_id": job_id,
"tier": tier,
"estimated_dataset_size": estimated_dataset_size,
"quota_usage": quota_result.get('current', 0) if quota_result else 0,
"quota_limit": quota_limit if quota_limit else "unlimited"
},
endpoint="/jobs",
method="POST"
)
except Exception as audit_error:
logger.warning("Failed to log audit event", error=str(audit_error))

return TrainingJobResponse(**response_data)

except HTTPException:
@@ -157,7 +291,8 @@ async def execute_training_job_background(
job_id: str,
bakery_location: tuple,
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None
requested_end: Optional[datetime] = None,
estimated_duration_minutes: int = 15
):
"""
Enhanced background task that executes the training job using repository pattern.
@@ -202,7 +337,7 @@ async def execute_training_job_background(
},
"requested_start": requested_start.isoformat() if requested_start else None,
"requested_end": requested_end.isoformat() if requested_end else None,
"estimated_duration_minutes": 18,
"estimated_duration_minutes": estimated_duration_minutes,
"background_execution": True,
"enhanced_features": True,
"repository_pattern": True,
@@ -278,16 +413,20 @@ async def execute_training_job_background(

@router.post(
route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
request: SingleProductTrainingRequest,
tenant_id: str = Path(..., description="Tenant ID"),
inventory_product_id: str = Path(..., description="Inventory product UUID"),
request_obj: Request = None,
current_user: Dict[str, Any] = Depends(get_current_user_dep),
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start enhanced training for a single product using repository pattern.
Start enhanced training for a single product (Admin+ only).

**RBAC:** Admin or Owner role required

Enhanced features:
- Repository pattern for data access

@@ -4,6 +4,13 @@ Training Service Models Package
Import all models to ensure they are registered with SQLAlchemy Base.
"""

# Import AuditLog model for this service
from shared.security import create_audit_log_model
from shared.database.base import Base

# Create audit log model for this service
AuditLog = create_audit_log_model(Base)

# Import all models to register them with the Base metadata
from .training import (
TrainedModel,
@@ -20,4 +27,5 @@ __all__ = [
"ModelPerformanceMetric",
"TrainingJobQueue",
"ModelArtifact",
"AuditLog",
]

@@ -193,4 +193,59 @@ class TrainedModel(Base):
"training_start_date": self.training_start_date.isoformat() if self.training_start_date else None,
"training_end_date": self.training_end_date.isoformat() if self.training_end_date else None,
"data_quality_score": self.data_quality_score
}


class TrainingPerformanceMetrics(Base):
"""
Table to track historical training performance for time estimation.
Stores aggregated metrics from completed training jobs.
"""
__tablename__ = "training_performance_metrics"

id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
job_id = Column(String(255), nullable=False, index=True)

# Training job statistics
total_products = Column(Integer, nullable=False)
successful_products = Column(Integer, nullable=False)
failed_products = Column(Integer, nullable=False)

# Time metrics
total_duration_seconds = Column(Float, nullable=False)
avg_time_per_product = Column(Float, nullable=False) # Key metric for estimation
data_analysis_time_seconds = Column(Float, nullable=True)
training_time_seconds = Column(Float, nullable=True)
finalization_time_seconds = Column(Float, nullable=True)

# Job metadata
completed_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))

def __repr__(self):
return (
f"<TrainingPerformanceMetrics("
f"tenant_id={self.tenant_id}, "
f"job_id={self.job_id}, "
f"total_products={self.total_products}, "
f"avg_time_per_product={self.avg_time_per_product:.2f}s"
f")>"
)

def to_dict(self):
return {
"id": str(self.id),
"tenant_id": str(self.tenant_id),
"job_id": self.job_id,
"total_products": self.total_products,
"successful_products": self.successful_products,
"failed_products": self.failed_products,
"total_duration_seconds": self.total_duration_seconds,
"avg_time_per_product": self.avg_time_per_product,
"data_analysis_time_seconds": self.data_analysis_time_seconds,
"training_time_seconds": self.training_time_seconds,
"finalization_time_seconds": self.finalization_time_seconds,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"created_at": self.created_at.isoformat() if self.created_at else None
}
@@ -112,6 +112,8 @@ class TrainingJobStatus(BaseModel):
products_completed: int = Field(0, description="Number of products completed")
products_failed: int = Field(0, description="Number of products that failed")
error_message: Optional[str] = Field(None, description="Error message if failed")
estimated_time_remaining_seconds: Optional[int] = Field(None, description="Estimated time remaining in seconds")
message: Optional[str] = Field(None, description="Optional status message")

@validator('job_id', pre=True)
def convert_uuid_to_string(cls, v):

@@ -38,10 +38,19 @@ async def cleanup_messaging():
async def publish_training_started(
job_id: str,
tenant_id: str,
total_products: int
total_products: int,
estimated_duration_minutes: Optional[int] = None,
estimated_completion_time: Optional[str] = None
) -> bool:
"""
Event 1: Training Started (0% progress)

Args:
job_id: Training job identifier
tenant_id: Tenant identifier
total_products: Number of products to train
estimated_duration_minutes: Estimated time to completion in minutes
estimated_completion_time: ISO timestamp of estimated completion
"""
event_data = {
"service_name": "training-service",
@@ -53,7 +62,10 @@ async def publish_training_started(
"progress": 0,
"current_step": "Training Started",
"step_details": f"Starting training for {total_products} products",
"total_products": total_products
"total_products": total_products,
"estimated_duration_minutes": estimated_duration_minutes,
"estimated_completion_time": estimated_completion_time,
"estimated_time_remaining_seconds": estimated_duration_minutes * 60 if estimated_duration_minutes else None
}
}

@@ -67,7 +79,8 @@ async def publish_training_started(
logger.info("Published training started event",
job_id=job_id,
tenant_id=tenant_id,
total_products=total_products)
total_products=total_products,
estimated_duration_minutes=estimated_duration_minutes)
else:
logger.error("Failed to publish training started event", job_id=job_id)

@@ -77,10 +90,17 @@ async def publish_training_started(
async def publish_data_analysis(
job_id: str,
tenant_id: str,
analysis_details: Optional[str] = None
analysis_details: Optional[str] = None,
estimated_time_remaining_seconds: Optional[int] = None
) -> bool:
"""
Event 2: Data Analysis (20% progress)

Args:
job_id: Training job identifier
tenant_id: Tenant identifier
analysis_details: Details about the analysis
estimated_time_remaining_seconds: Estimated time remaining in seconds
"""
event_data = {
"service_name": "training-service",
@@ -91,7 +111,8 @@ async def publish_data_analysis(
"tenant_id": tenant_id,
"progress": 20,
"current_step": "Data Analysis",
"step_details": analysis_details or "Analyzing sales, weather, and traffic data"
"step_details": analysis_details or "Analyzing sales, weather, and traffic data",
"estimated_time_remaining_seconds": estimated_time_remaining_seconds
}
}

@@ -116,7 +137,8 @@ async def publish_product_training_completed(
tenant_id: str,
product_name: str,
products_completed: int,
total_products: int
total_products: int,
estimated_time_remaining_seconds: Optional[int] = None
) -> bool:
"""
Event 3: Product Training Completed (contributes to 20-80% progress)
@@ -124,6 +146,14 @@ async def publish_product_training_completed(
This event is published each time a product training completes.
The frontend/consumer will calculate the progress as:
progress = 20 + (products_completed / total_products) * 60
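For example, with 6 of 10 products completed: progress = 20 + (6/10) * 60 = 56.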

Args:
job_id: Training job identifier
tenant_id: Tenant identifier
product_name: Name of the product that was trained
products_completed: Number of products completed so far
total_products: Total number of products
estimated_time_remaining_seconds: Estimated time remaining in seconds
"""
event_data = {
"service_name": "training-service",
@@ -136,7 +166,8 @@ async def publish_product_training_completed(
"products_completed": products_completed,
"total_products": total_products,
"current_step": "Model Training",
"step_details": f"Completed training for {product_name} ({products_completed}/{total_products})"
"step_details": f"Completed training for {product_name} ({products_completed}/{total_products})",
"estimated_time_remaining_seconds": estimated_time_remaining_seconds
}
}

@@ -452,23 +452,50 @@ class EnhancedTrainingService:
try:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)

log = await self.training_log_repo.get_log_by_job_id(job_id)
if not log:
return {"error": "Job not found"}

# Calculate estimated time remaining based on progress and elapsed time
estimated_time_remaining_seconds = None
if log.status == "running" and log.progress > 0 and log.start_time:
from datetime import datetime, timezone
elapsed_time = (datetime.now(timezone.utc) - log.start_time).total_seconds()
if elapsed_time > 0:
# Calculate estimated total time based on progress
estimated_total_time = (elapsed_time / log.progress) * 100
estimated_time_remaining_seconds = int(estimated_total_time - elapsed_time)
# Cap at reasonable maximum (e.g., 30 minutes)
estimated_time_remaining_seconds = min(estimated_time_remaining_seconds, 1800)
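# e.g. 240s elapsed at 40% progress -> ~600s total, ~360s remaining (capped at 1800s)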

# Extract products info from results if available
products_total = 0
products_completed = 0
products_failed = 0

if log.results:
products_total = log.results.get("total_products", 0)
products_completed = log.results.get("successful_trainings", 0)
products_failed = log.results.get("failed_trainings", 0)

return {
"job_id": job_id,
"tenant_id": log.tenant_id,
"status": log.status,
"progress": log.progress,
"current_step": log.current_step,
"start_time": log.start_time.isoformat() if log.start_time else None,
"end_time": log.end_time.isoformat() if log.end_time else None,
"started_at": log.start_time.isoformat() if log.start_time else None,
"completed_at": log.end_time.isoformat() if log.end_time else None,
"error_message": log.error_message,
"results": log.results
"results": log.results,
"products_total": products_total,
"products_completed": products_completed,
"products_failed": products_failed,
"estimated_time_remaining_seconds": estimated_time_remaining_seconds,
"message": log.current_step
}

except Exception as e:
logger.error("Failed to get training status",
job_id=job_id,
services/training/app/utils/time_estimation.py (new file, 332 lines)
@@ -0,0 +1,332 @@
"""
Training Time Estimation Utilities
Provides intelligent time estimation for training jobs based on:
- Product count
- Historical performance data
- Current progress and throughput
"""

from typing import List, Optional
from datetime import datetime, timedelta, timezone
import structlog

logger = structlog.get_logger()


def calculate_initial_estimate(
total_products: int,
avg_training_time_per_product: float = 60.0, # seconds, default 1 min/product
data_analysis_overhead: float = 120.0, # seconds, data loading & analysis
finalization_overhead: float = 60.0, # seconds, saving models & cleanup
min_estimate_minutes: int = 5,
max_estimate_minutes: int = 60
) -> int:
"""
Calculate realistic initial time estimate for training job.

Formula:
total_time = data_analysis + (products * avg_time_per_product) + finalization

Args:
total_products: Number of products to train
avg_training_time_per_product: Average time per product in seconds
data_analysis_overhead: Time for data loading and analysis in seconds
finalization_overhead: Time for saving models and cleanup in seconds
min_estimate_minutes: Minimum estimate (prevents unrealistic low values)
max_estimate_minutes: Maximum estimate (prevents unrealistic high values)

Returns:
Estimated duration in minutes

Examples:
>>> calculate_initial_estimate(1)
5 # 120 + 60 + 60 = 240s = 4 min, raised to the 5-minute minimum

>>> calculate_initial_estimate(5)
8 # 120 + 300 + 60 = 480s = 8min

>>> calculate_initial_estimate(10)
13 # 120 + 600 + 60 = 780s = 13min

>>> calculate_initial_estimate(20)
23 # 120 + 1200 + 60 = 1380s = 23min

>>> calculate_initial_estimate(100)
60 # Capped at max (would be 103 min)
"""
# Calculate total estimated time in seconds
estimated_seconds = (
data_analysis_overhead +
(total_products * avg_training_time_per_product) +
finalization_overhead
)

# Convert to minutes, rounding to the nearest minute
estimated_minutes = int((estimated_seconds / 60) + 0.5)

# Apply min/max bounds
estimated_minutes = max(min_estimate_minutes, min(max_estimate_minutes, estimated_minutes))

logger.info(
"Calculated initial time estimate",
total_products=total_products,
estimated_seconds=estimated_seconds,
estimated_minutes=estimated_minutes,
avg_time_per_product=avg_training_time_per_product
)

return estimated_minutes


def calculate_estimated_completion_time(
estimated_duration_minutes: int,
start_time: Optional[datetime] = None
) -> datetime:
"""
Calculate estimated completion timestamp.

Args:
estimated_duration_minutes: Estimated duration in minutes
start_time: Job start time (defaults to now)

Returns:
Estimated completion datetime (timezone-aware UTC)
"""
if start_time is None:
start_time = datetime.now(timezone.utc)

completion_time = start_time + timedelta(minutes=estimated_duration_minutes)

return completion_time


def calculate_remaining_time_smart(
progress: int,
elapsed_time: float,
products_completed: int,
total_products: int,
recent_product_times: Optional[List[float]] = None,
max_remaining_seconds: int = 1800 # 30 minutes
) -> Optional[int]:
"""
Calculate remaining time using smart algorithm that considers:
- Current progress percentage
- Actual throughput (products completed / elapsed time)
- Recent performance (weighted moving average)

Args:
progress: Current progress percentage (0-100)
elapsed_time: Time elapsed since job start (seconds)
products_completed: Number of products completed
total_products: Total number of products
recent_product_times: List of recent product training times (seconds)
max_remaining_seconds: Maximum remaining time (safety cap)

Returns:
Estimated remaining time in seconds, or None if can't calculate
"""
# Job completed or not started
if progress >= 100 or progress <= 0:
return None

# Early stage (0-20%): Use weighted estimate
if progress <= 20:
# In data analysis phase - estimate based on remaining products
remaining_products = total_products - products_completed

if recent_product_times and len(recent_product_times) > 0:
# Use recent performance if available
avg_time_per_product = sum(recent_product_times) / len(recent_product_times)
else:
# Fallback to default
avg_time_per_product = 60.0 # 1 minute per product

# Estimate: remaining products * avg time + overhead
estimated_remaining = (remaining_products * avg_time_per_product) + 60.0 # +1 min overhead

logger.debug(
"Early stage estimation",
progress=progress,
remaining_products=remaining_products,
avg_time_per_product=avg_time_per_product,
estimated_remaining=estimated_remaining
)

# Mid/late stage (21-99%): Use actual throughput
else:
if products_completed > 0:
# Calculate actual time per product from current run
actual_time_per_product = elapsed_time / products_completed
remaining_products = total_products - products_completed
estimated_remaining = remaining_products * actual_time_per_product

logger.debug(
"Mid/late stage estimation",
progress=progress,
products_completed=products_completed,
total_products=total_products,
actual_time_per_product=actual_time_per_product,
estimated_remaining=estimated_remaining
)
else:
# Fallback to linear extrapolation
estimated_total = (elapsed_time / progress) * 100
estimated_remaining = estimated_total - elapsed_time
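# e.g. 300s elapsed at 25% progress -> estimated total 1200s, 900s remaining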

logger.debug(
"Fallback linear estimation",
progress=progress,
elapsed_time=elapsed_time,
estimated_remaining=estimated_remaining
)

# Apply safety cap
estimated_remaining = min(estimated_remaining, max_remaining_seconds)

return int(estimated_remaining)


def calculate_average_product_time(
products_completed: int,
elapsed_time: float,
min_products_threshold: int = 3
) -> Optional[float]:
"""
Calculate average time per product from current job progress.

Args:
products_completed: Number of products completed
elapsed_time: Time elapsed since job start (seconds)
min_products_threshold: Minimum products needed for reliable calculation

Returns:
Average time per product in seconds, or None if insufficient data
"""
if products_completed < min_products_threshold:
return None

avg_time = elapsed_time / products_completed

logger.debug(
"Calculated average product time",
products_completed=products_completed,
elapsed_time=elapsed_time,
avg_time=avg_time
)

return avg_time


def format_time_remaining(seconds: int) -> str:
"""
Format remaining time in human-readable format.

Args:
seconds: Time in seconds

Returns:
Formatted string (e.g., "5 minutes", "1 hour 23 minutes")

Examples:
>>> format_time_remaining(45)
"45 seconds"

>>> format_time_remaining(180)
"3 minutes"

>>> format_time_remaining(5400)
"1 hour 30 minutes"
"""
if seconds < 60:
return f"{seconds} seconds"

minutes = seconds // 60
remaining_seconds = seconds % 60

if minutes < 60:
if remaining_seconds > 0:
return f"{minutes} minutes {remaining_seconds} seconds"
return f"{minutes} minutes"

hours = minutes // 60
remaining_minutes = minutes % 60

if remaining_minutes > 0:
return f"{hours} hour{'s' if hours > 1 else ''} {remaining_minutes} minutes"
return f"{hours} hour{'s' if hours > 1 else ''}"


def get_historical_average_estimate(
db_session,
tenant_id: str,
lookback_days: int = 30,
limit: int = 10
) -> Optional[float]:
"""
Get historical average training time per product for a tenant.

This function queries the TrainingPerformanceMetrics table to get
recent historical data and calculate an average.

Args:
db_session: Database session
tenant_id: Tenant UUID
lookback_days: How many days back to look
limit: Maximum number of historical records to consider

Returns:
Average time per product in seconds, or None if no historical data
"""
try:
from app.models.training import TrainingPerformanceMetrics
from datetime import timedelta

cutoff = datetime.now(timezone.utc) - timedelta(days=lookback_days)

# Query recent training performance metrics
metrics = db_session.query(TrainingPerformanceMetrics).filter(
TrainingPerformanceMetrics.tenant_id == tenant_id,
TrainingPerformanceMetrics.completed_at >= cutoff
).order_by(
TrainingPerformanceMetrics.completed_at.desc()
).limit(limit).all()

if not metrics:
logger.info(
"No historical training data found",
tenant_id=tenant_id,
lookback_days=lookback_days
)
return None

# Calculate weighted average (more recent = higher weight)
total_weight = 0
weighted_sum = 0

for i, metric in enumerate(metrics):
# Weight: newer records get higher weight
weight = limit - i
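# e.g. with limit=10, the newest record gets weight 10, the next 9, and so on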
weighted_sum += metric.avg_time_per_product * weight
total_weight += weight

if total_weight == 0:
return None

weighted_avg = weighted_sum / total_weight

logger.info(
"Calculated historical average",
tenant_id=tenant_id,
records_used=len(metrics),
weighted_avg=weighted_avg
)

return weighted_avg

except Exception as e:
logger.error(
"Error getting historical average",
tenant_id=tenant_id,
error=str(e)
)
return None
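
A minimal usage sketch of the new estimation helpers (illustrative only, not part of the commit; it assumes the module is importable as app.utils.time_estimation):

from app.utils.time_estimation import (
    calculate_initial_estimate,
    calculate_estimated_completion_time,
    format_time_remaining,
)

# Initial estimate for 12 products with the default 60s/product assumption:
# 120s analysis + 12 * 60s + 60s finalization = 900s -> 15 minutes
minutes = calculate_initial_estimate(total_products=12)

# Timezone-aware UTC timestamp for the expected finish
eta = calculate_estimated_completion_time(minutes)

# Human-readable remaining time, e.g. "15 minutes"
print(format_time_remaining(minutes * 60))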