Add role-based filtering and improve code

This commit is contained in:
Urtzi Alfaro
2025-10-15 16:12:49 +02:00
parent 96ad5c6692
commit 8f9e9a7edc
158 changed files with 11033 additions and 1544 deletions

View File

@@ -8,11 +8,19 @@ from typing import Optional, Dict, Any
import structlog
from datetime import datetime, timezone
import uuid
import shared.redis_utils
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role, admin_role_required
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
from shared.subscription.plans import (
get_training_job_quota,
get_dataset_size_limit
)
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
@@ -20,6 +28,11 @@ from app.schemas.training import (
SingleProductTrainingRequest,
TrainingJobResponse
)
from app.utils.time_estimation import (
calculate_initial_estimate,
calculate_estimated_completion_time,
get_historical_average_estimate
)
from app.services.training_events import (
publish_training_started,
publish_training_completed,
@@ -32,6 +45,30 @@ route_builder = RouteBuilder('training')
router = APIRouter(tags=["training-operations"])
# Initialize audit logger
audit_logger = create_audit_logger("training-service")
# Redis client for rate limiting
_redis_client = None
async def get_training_redis_client():
    """Return the module-level Redis client used for rate limiting, creating it lazily.

    On first call, tries to initialize a new Redis connection from the
    service settings. If initialization fails (e.g. Redis was already
    initialized elsewhere in the process), falls back to fetching the
    existing client from ``shared.redis_utils``.

    Returns:
        The shared Redis client instance (cached in ``_redis_client``).
    """
    global _redis_client
    if _redis_client is None:
        try:
            # Imported here to avoid a circular import at module load time —
            # TODO confirm that is the reason for the local import.
            from app.core.config import settings
            _redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        except Exception:
            # Was a bare ``except:``, which also swallowed SystemExit and
            # KeyboardInterrupt; narrowed to Exception. Fallback: reuse a
            # client that was already initialized elsewhere.
            _redis_client = await shared.redis_utils.get_redis_client()
    return _redis_client
async def get_rate_limiter():
    """FastAPI dependency: build a rate limiter backed by the shared Redis client."""
    return create_rate_limiter(await get_training_redis_client())
def get_enhanced_training_service():
"""Dependency injection for EnhancedTrainingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
@@ -40,31 +77,82 @@ def get_enhanced_training_service():
@router.post(
route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
request: TrainingJobRequest,
tenant_id: str = Path(..., description="Tenant ID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
current_user: Dict[str, Any] = Depends(get_current_user_dep),
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service),
rate_limiter = Depends(get_rate_limiter)
):
"""
Start a new training job for all tenant products using repository pattern.
Start a new training job for all tenant products (Admin+ only, quota enforced).
**RBAC:** Admin or Owner role required
**Quotas:**
- Starter: 1 training job/day, max 1,000 rows
- Professional: 5 training jobs/day, max 10,000 rows
- Enterprise: Unlimited jobs, unlimited rows
Enhanced immediate response pattern:
1. Validate request with enhanced validation
2. Create job record using repository pattern
3. Return 200 with enhanced job details
4. Execute enhanced training in background with repository tracking
1. Validate subscription tier and quotas
2. Validate request with enhanced validation
3. Create job record using repository pattern
4. Return 200 with enhanced job details
5. Execute enhanced training in background with repository tracking
Enhanced features:
- Repository pattern for data access
- Quota enforcement by subscription tier
- Audit logging for all operations
- Enhanced error handling and logging
- Metrics tracking and monitoring
- Transactional operations
"""
metrics = get_metrics_collector(request_obj)
# Get subscription tier and enforce quotas
tier = current_user.get('subscription_tier', 'starter')
# Estimate dataset size (this should come from the request or be calculated)
# For now, we'll assume a reasonable estimate
estimated_dataset_size = request.estimated_rows if hasattr(request, 'estimated_rows') else 500
# Initialize variables for later use
quota_result = None
quota_limit = None
try:
# Validate dataset size limits
await rate_limiter.validate_dataset_size(
tenant_id, estimated_dataset_size, tier
)
# Check daily training job quota
quota_limit = get_training_job_quota(tier)
quota_result = await rate_limiter.check_and_increment_quota(
tenant_id,
"training_jobs",
quota_limit,
period=86400 # 24 hours
)
logger.info("Training job quota check passed",
tenant_id=tenant_id,
tier=tier,
current_usage=quota_result.get('current', 0) if quota_result else 0,
limit=quota_limit)
except HTTPException:
# Quota or validation error - re-raise
raise
except Exception as quota_error:
logger.error("Quota validation failed", error=str(quota_error))
# Continue with job creation but log the error
try:
# Generate enhanced job ID
job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
@@ -85,6 +173,25 @@ async def start_training_job(
total_products=0 # Will be updated when actual training starts
)
# Calculate intelligent time estimate
# We don't know exact product count yet, so use historical average or estimate
try:
# Try to get historical average for this tenant
from app.core.database import get_db
db = next(get_db())
historical_avg = get_historical_average_estimate(db, tenant_id)
# If no historical data, estimate based on typical product count (10-20 products)
estimated_products = 15 # Conservative estimate
estimated_duration_minutes = calculate_initial_estimate(
total_products=estimated_products,
avg_training_time_per_product=historical_avg if historical_avg else 60.0
)
except Exception as est_error:
logger.warning("Could not calculate intelligent estimate, using default",
error=str(est_error))
estimated_duration_minutes = 15 # Default fallback
# Add enhanced background task
background_tasks.add_task(
execute_training_job_background,
@@ -92,7 +199,8 @@ async def start_training_job(
job_id=job_id,
bakery_location=(40.4168, -3.7038),
requested_start=request.start_date,
requested_end=request.end_date
requested_end=request.end_date,
estimated_duration_minutes=estimated_duration_minutes
)
# Return enhanced immediate success response
@@ -102,7 +210,7 @@ async def start_training_job(
"status": "pending",
"message": "Enhanced training job started successfully using repository pattern",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": 18,
"estimated_duration_minutes": estimated_duration_minutes,
"training_results": {
"total_products": 0,
"successful_trainings": 0,
@@ -126,6 +234,32 @@ async def start_training_job(
job_id=job_id,
features=["repository-pattern", "dependency-injection", "enhanced-tracking"])
# Log audit event for training job creation
try:
from app.core.database import get_db
db = next(get_db())
await audit_logger.log_event(
db_session=db,
tenant_id=tenant_id,
user_id=current_user["user_id"],
action=AuditAction.CREATE.value,
resource_type="training_job",
resource_id=job_id,
severity=AuditSeverity.MEDIUM.value,
description=f"Started training job (tier: {tier})",
metadata={
"job_id": job_id,
"tier": tier,
"estimated_dataset_size": estimated_dataset_size,
"quota_usage": quota_result.get('current', 0) if quota_result else 0,
"quota_limit": quota_limit if quota_limit else "unlimited"
},
endpoint="/jobs",
method="POST"
)
except Exception as audit_error:
logger.warning("Failed to log audit event", error=str(audit_error))
return TrainingJobResponse(**response_data)
except HTTPException:
@@ -157,7 +291,8 @@ async def execute_training_job_background(
job_id: str,
bakery_location: tuple,
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None
requested_end: Optional[datetime] = None,
estimated_duration_minutes: int = 15
):
"""
Enhanced background task that executes the training job using repository pattern.
@@ -202,7 +337,7 @@ async def execute_training_job_background(
},
"requested_start": requested_start.isoformat() if requested_start else None,
"requested_end": requested_end.isoformat() if requested_end else None,
"estimated_duration_minutes": 18,
"estimated_duration_minutes": estimated_duration_minutes,
"background_execution": True,
"enhanced_features": True,
"repository_pattern": True,
@@ -278,16 +413,20 @@ async def execute_training_job_background(
@router.post(
route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
request: SingleProductTrainingRequest,
tenant_id: str = Path(..., description="Tenant ID"),
inventory_product_id: str = Path(..., description="Inventory product UUID"),
request_obj: Request = None,
current_user: Dict[str, Any] = Depends(get_current_user_dep),
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start enhanced training for a single product using repository pattern.
Start enhanced training for a single product (Admin+ only).
**RBAC:** Admin or Owner role required
Enhanced features:
- Repository pattern for data access