Add role-based filtering and improve code
@@ -8,11 +8,19 @@ from typing import Optional, Dict, Any
import structlog
from datetime import datetime, timezone
import uuid
import shared.redis_utils

from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role, admin_role_required
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
from shared.subscription.plans import (
    get_training_job_quota,
    get_dataset_size_limit
)

from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
@@ -20,6 +28,11 @@ from app.schemas.training import (
    SingleProductTrainingRequest,
    TrainingJobResponse
)
from app.utils.time_estimation import (
    calculate_initial_estimate,
    calculate_estimated_completion_time,
    get_historical_average_estimate
)
from app.services.training_events import (
    publish_training_started,
    publish_training_completed,
@@ -32,6 +45,30 @@ route_builder = RouteBuilder('training')

router = APIRouter(tags=["training-operations"])

# Initialize audit logger
audit_logger = create_audit_logger("training-service")

# Redis client for rate limiting
_redis_client = None

async def get_training_redis_client():
    """Get or create Redis client for rate limiting"""
    global _redis_client
    if _redis_client is None:
        # Initialize Redis if not already done
        try:
            from app.core.config import settings
            _redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        except:
            # Fallback to getting the client directly (if already initialized elsewhere)
            _redis_client = await shared.redis_utils.get_redis_client()
    return _redis_client

async def get_rate_limiter():
    """Dependency for rate limiter"""
    redis_client = await get_training_redis_client()
    return create_rate_limiter(redis_client)

def get_enhanced_training_service():
    """Dependency injection for EnhancedTrainingService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
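The lazy Redis client and the rate limiter dependency above are what later back the rate_limiter.check_and_increment_quota(...) call in start_training_job. For orientation only, here is a minimal sketch of a fixed-window counter with that signature, assuming a plain INCR/EXPIRE scheme; the class name and behaviour are illustrative guesses, not the actual shared.security implementation:

# Hypothetical fixed-window quota counter; the real object returned by
# create_rate_limiter() in shared.security may work differently.
from typing import Optional
from fastapi import HTTPException
import redis.asyncio as redis

class SimpleQuotaLimiter:
    def __init__(self, client: redis.Redis):
        self._client = client

    async def check_and_increment_quota(self, tenant_id: str, resource: str,
                                        limit: Optional[int], period: int = 86400) -> dict:
        # A limit of None is treated as "unlimited" and never throttled.
        if limit is None:
            return {"current": 0, "limit": None}
        key = f"quota:{tenant_id}:{resource}"
        current = await self._client.incr(key)
        if current == 1:
            # First hit in this window: start the expiry clock (24h by default).
            await self._client.expire(key, period)
        if current > limit:
            raise HTTPException(status_code=429, detail=f"Daily {resource} quota exceeded")
        return {"current": current, "limit": limit}

A production limiter would also provide validate_dataset_size and make the increment-and-expire step atomic (for example via a Lua script); both are omitted from this sketch.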
@@ -40,31 +77,82 @@ def get_enhanced_training_service():

@router.post(
    route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
    request: TrainingJobRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    request_obj: Request = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service),
    rate_limiter = Depends(get_rate_limiter)
):
    """
    Start a new training job for all tenant products using repository pattern.
    Start a new training job for all tenant products (Admin+ only, quota enforced).

    **RBAC:** Admin or Owner role required
    **Quotas:**
    - Starter: 1 training job/day, max 1,000 rows
    - Professional: 5 training jobs/day, max 10,000 rows
    - Enterprise: Unlimited jobs, unlimited rows

    Enhanced immediate response pattern:
    1. Validate request with enhanced validation
    2. Create job record using repository pattern
    3. Return 200 with enhanced job details
    4. Execute enhanced training in background with repository tracking
    1. Validate subscription tier and quotas
    2. Validate request with enhanced validation
    3. Create job record using repository pattern
    4. Return 200 with enhanced job details
    5. Execute enhanced training in background with repository tracking

    Enhanced features:
    - Repository pattern for data access
    - Quota enforcement by subscription tier
    - Audit logging for all operations
    - Enhanced error handling and logging
    - Metrics tracking and monitoring
    - Transactional operations
    """
    metrics = get_metrics_collector(request_obj)

    # Get subscription tier and enforce quotas
    tier = current_user.get('subscription_tier', 'starter')

    # Estimate dataset size (this should come from the request or be calculated)
    # For now, we'll assume a reasonable estimate
    estimated_dataset_size = request.estimated_rows if hasattr(request, 'estimated_rows') else 500

    # Initialize variables for later use
    quota_result = None
    quota_limit = None

    try:
        # Validate dataset size limits
        await rate_limiter.validate_dataset_size(
            tenant_id, estimated_dataset_size, tier
        )

        # Check daily training job quota
        quota_limit = get_training_job_quota(tier)
        quota_result = await rate_limiter.check_and_increment_quota(
            tenant_id,
            "training_jobs",
            quota_limit,
            period=86400  # 24 hours
        )

        logger.info("Training job quota check passed",
                    tenant_id=tenant_id,
                    tier=tier,
                    current_usage=quota_result.get('current', 0) if quota_result else 0,
                    limit=quota_limit)

    except HTTPException:
        # Quota or validation error - re-raise
        raise
    except Exception as quota_error:
        logger.error("Quota validation failed", error=str(quota_error))
        # Continue with job creation but log the error

    try:
        # Generate enhanced job ID
        job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
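get_training_job_quota and get_dataset_size_limit are imported from shared.subscription.plans, whose body is not part of this diff. Going only by the limits listed in the endpoint docstring (Starter 1 job/day and 1,000 rows, Professional 5 jobs/day and 10,000 rows, Enterprise unlimited), a plausible sketch of those helpers could be:

# Hypothetical tier tables inferred from the docstring above; the real
# shared.subscription.plans module may define these differently.
from typing import Optional

_TRAINING_JOB_QUOTAS = {"starter": 1, "professional": 5, "enterprise": None}
_DATASET_SIZE_LIMITS = {"starter": 1_000, "professional": 10_000, "enterprise": None}

def get_training_job_quota(tier: str) -> Optional[int]:
    # None means "unlimited"; unknown tiers fall back to the most restrictive plan.
    return _TRAINING_JOB_QUOTAS.get(tier.lower(), _TRAINING_JOB_QUOTAS["starter"])

def get_dataset_size_limit(tier: str) -> Optional[int]:
    return _DATASET_SIZE_LIMITS.get(tier.lower(), _DATASET_SIZE_LIMITS["starter"])

Falling back to the starter limits for unrecognised tiers would mirror the endpoint's own default of current_user.get('subscription_tier', 'starter').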
@@ -85,6 +173,25 @@ async def start_training_job(
            total_products=0  # Will be updated when actual training starts
        )

        # Calculate intelligent time estimate
        # We don't know exact product count yet, so use historical average or estimate
        try:
            # Try to get historical average for this tenant
            from app.core.database import get_db
            db = next(get_db())
            historical_avg = get_historical_average_estimate(db, tenant_id)

            # If no historical data, estimate based on typical product count (10-20 products)
            estimated_products = 15  # Conservative estimate
            estimated_duration_minutes = calculate_initial_estimate(
                total_products=estimated_products,
                avg_training_time_per_product=historical_avg if historical_avg else 60.0
            )
        except Exception as est_error:
            logger.warning("Could not calculate intelligent estimate, using default",
                           error=str(est_error))
            estimated_duration_minutes = 15  # Default fallback

        # Add enhanced background task
        background_tasks.add_task(
            execute_training_job_background,
@@ -92,7 +199,8 @@ async def start_training_job(
            job_id=job_id,
            bakery_location=(40.4168, -3.7038),
            requested_start=request.start_date,
            requested_end=request.end_date
            requested_end=request.end_date,
            estimated_duration_minutes=estimated_duration_minutes
        )
        # Return enhanced immediate success response
@@ -102,7 +210,7 @@ async def start_training_job(
            "status": "pending",
            "message": "Enhanced training job started successfully using repository pattern",
            "created_at": datetime.now(timezone.utc),
            "estimated_duration_minutes": 18,
            "estimated_duration_minutes": estimated_duration_minutes,
            "training_results": {
                "total_products": 0,
                "successful_trainings": 0,
@@ -126,6 +234,32 @@ async def start_training_job(
                    job_id=job_id,
                    features=["repository-pattern", "dependency-injection", "enhanced-tracking"])

        # Log audit event for training job creation
        try:
            from app.core.database import get_db
            db = next(get_db())
            await audit_logger.log_event(
                db_session=db,
                tenant_id=tenant_id,
                user_id=current_user["user_id"],
                action=AuditAction.CREATE.value,
                resource_type="training_job",
                resource_id=job_id,
                severity=AuditSeverity.MEDIUM.value,
                description=f"Started training job (tier: {tier})",
                metadata={
                    "job_id": job_id,
                    "tier": tier,
                    "estimated_dataset_size": estimated_dataset_size,
                    "quota_usage": quota_result.get('current', 0) if quota_result else 0,
                    "quota_limit": quota_limit if quota_limit else "unlimited"
                },
                endpoint="/jobs",
                method="POST"
            )
        except Exception as audit_error:
            logger.warning("Failed to log audit event", error=str(audit_error))

        return TrainingJobResponse(**response_data)
    except HTTPException:
@@ -157,7 +291,8 @@ async def execute_training_job_background(
    job_id: str,
    bakery_location: tuple,
    requested_start: Optional[datetime] = None,
    requested_end: Optional[datetime] = None
    requested_end: Optional[datetime] = None,
    estimated_duration_minutes: int = 15
):
    """
    Enhanced background task that executes the training job using repository pattern.
@@ -202,7 +337,7 @@ async def execute_training_job_background(
        },
        "requested_start": requested_start.isoformat() if requested_start else None,
        "requested_end": requested_end.isoformat() if requested_end else None,
        "estimated_duration_minutes": 18,
        "estimated_duration_minutes": estimated_duration_minutes,
        "background_execution": True,
        "enhanced_features": True,
        "repository_pattern": True,
@@ -278,16 +413,20 @@ async def execute_training_job_background(
@router.post(
    route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    inventory_product_id: str = Path(..., description="Inventory product UUID"),
    request_obj: Request = None,
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Start enhanced training for a single product using repository pattern.
    Start enhanced training for a single product (Admin+ only).

    **RBAC:** Admin or Owner role required

    Enhanced features:
    - Repository pattern for data access
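Both endpoints gate access with @require_user_role(['admin', 'owner']) from shared.auth.access_control, whose implementation is outside this diff. As a rough illustration only (the kwargs-based role lookup and the 403 message are assumptions), a role-gating decorator for handlers that already receive current_user via Depends(get_current_user_dep) might look like:

# Hypothetical sketch of a role-check decorator; the real
# shared.auth.access_control.require_user_role may differ.
import functools
from typing import Any, Dict, List
from fastapi import HTTPException

def require_user_role(allowed_roles: List[str]):
    def decorator(endpoint):
        @functools.wraps(endpoint)
        async def wrapper(*args, **kwargs):
            # FastAPI passes resolved dependencies as keyword arguments,
            # so the authenticated user is available here.
            current_user: Dict[str, Any] = kwargs.get("current_user") or {}
            if current_user.get("role") not in allowed_roles:
                raise HTTPException(status_code=403, detail="Insufficient role for this operation")
            return await endpoint(*args, **kwargs)
        return wrapper
    return decorator

Because inspect.signature follows the wrapped function, FastAPI still sees the original parameters and keeps injecting current_user, tenant_id and the service dependencies.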
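Finally, the intelligent time estimate in start_training_job calls calculate_initial_estimate(total_products=..., avg_training_time_per_product=...) with a 60-second-per-product fallback. That helper lives in app.utils.time_estimation and is not shown here; a minimal sketch consistent with the call site (the rounding and the one-minute floor are assumptions) would be:

# Hypothetical sketch matching the call site; the real
# app.utils.time_estimation.calculate_initial_estimate may differ.
import math

def calculate_initial_estimate(total_products: int,
                               avg_training_time_per_product: float = 60.0) -> int:
    """Estimated duration in whole minutes for training total_products models."""
    total_seconds = total_products * avg_training_time_per_product
    # Round up and keep a one-minute floor so the API never reports a 0-minute job.
    return max(1, math.ceil(total_seconds / 60.0))

With the defaults used in the diff (15 products at 60 seconds each) this works out to 15 minutes, matching the hard-coded fallback estimate.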