Improve the frontend
This commit is contained in:
@@ -9,6 +9,7 @@ import structlog
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
import shared.redis_utils
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from shared.routing import RouteBuilder
|
||||
from shared.monitoring.decorators import track_execution_time
|
||||
@@ -39,6 +40,7 @@ from app.services.training_events import (
|
||||
publish_training_failed
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
|
||||
logger = structlog.get_logger()
|
||||
route_builder = RouteBuilder('training')
|
||||
@@ -86,7 +88,8 @@ async def start_training_job(
|
||||
request_obj: Request = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service),
|
||||
rate_limiter = Depends(get_rate_limiter)
|
||||
rate_limiter = Depends(get_rate_limiter),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Start a new training job for all tenant products (Admin+ only, quota enforced).
|
||||
@@ -169,9 +172,7 @@ async def start_training_job(
|
||||
# We don't know exact product count yet, so use historical average or estimate
|
||||
try:
|
||||
# Try to get historical average for this tenant
|
||||
from app.core.database import get_db
|
||||
db = next(get_db())
|
||||
historical_avg = get_historical_average_estimate(db, tenant_id)
|
||||
historical_avg = await get_historical_average_estimate(db, tenant_id)
|
||||
|
||||
# If no historical data, estimate based on typical product count (10-20 products)
|
||||
estimated_products = 15 # Conservative estimate
|
||||
|
||||
@@ -129,9 +129,7 @@ class EnhancedBakeryMLTrainer:
|
||||
|
||||
# Try to get historical average for more accurate estimates
|
||||
try:
|
||||
historical_avg = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
get_historical_average_estimate,
|
||||
historical_avg = await get_historical_average_estimate(
|
||||
db_session,
|
||||
tenant_id
|
||||
)
|
||||
|
||||
@@ -9,6 +9,8 @@ Provides intelligent time estimation for training jobs based on:
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@@ -256,8 +258,8 @@ def format_time_remaining(seconds: int) -> str:
|
||||
return f"{hours} hour{'s' if hours > 1 else ''}"
|
||||
|
||||
|
||||
def get_historical_average_estimate(
|
||||
db_session,
|
||||
async def get_historical_average_estimate(
|
||||
db_session: AsyncSession,
|
||||
tenant_id: str,
|
||||
lookback_days: int = 30,
|
||||
limit: int = 10
|
||||
@@ -269,7 +271,7 @@ def get_historical_average_estimate(
|
||||
recent historical data and calculate an average.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
db_session: Async database session
|
||||
tenant_id: Tenant UUID
|
||||
lookback_days: How many days back to look
|
||||
limit: Maximum number of historical records to consider
|
||||
@@ -283,13 +285,19 @@ def get_historical_average_estimate(
|
||||
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=lookback_days)
|
||||
|
||||
# Query recent training performance metrics
|
||||
metrics = db_session.query(TrainingPerformanceMetrics).filter(
|
||||
TrainingPerformanceMetrics.tenant_id == tenant_id,
|
||||
TrainingPerformanceMetrics.completed_at >= cutoff
|
||||
).order_by(
|
||||
TrainingPerformanceMetrics.completed_at.desc()
|
||||
).limit(limit).all()
|
||||
# Query recent training performance metrics using SQLAlchemy 2.0 async pattern
|
||||
query = (
|
||||
select(TrainingPerformanceMetrics)
|
||||
.where(
|
||||
TrainingPerformanceMetrics.tenant_id == tenant_id,
|
||||
TrainingPerformanceMetrics.completed_at >= cutoff
|
||||
)
|
||||
.order_by(TrainingPerformanceMetrics.completed_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
metrics = result.scalars().all()
|
||||
|
||||
if not metrics:
|
||||
logger.info(
|
||||
|
||||
Reference in New Issue
Block a user