Add base kubernetes support final fix 4

This commit is contained in:
Urtzi Alfaro
2025-09-29 07:54:25 +02:00
parent 57f77638cc
commit 4777e59e7a
14 changed files with 1041 additions and 167 deletions

View File

@@ -131,15 +131,28 @@ class OnboardingService:
# Update the step
await self._update_user_onboarding_data(
user_id,
step_name,
user_id,
step_name,
{
"completed": update_request.completed,
"completed_at": datetime.now(timezone.utc).isoformat() if update_request.completed else None,
"data": update_request.data or {}
}
)
# Try to update summary and handle partial failures gracefully
try:
# Update the user's onboarding summary
await self._update_user_summary(user_id)
except HTTPException as he:
# If it's a 207 Multi-Status (partial success), log warning but continue
if he.status_code == status.HTTP_207_MULTI_STATUS:
logger.warning(f"Summary update failed for user {user_id}, step {step_name}: {he.detail}")
# Continue execution - the step update was successful
else:
# Re-raise other HTTP exceptions
raise
# Return updated progress
return await self.get_user_progress(user_id)
@@ -284,10 +297,7 @@ class OnboardingService:
completed=completed,
step_data=data_payload
)
# Update the user's onboarding summary
await self._update_user_summary(user_id)
logger.info(f"Successfully updated onboarding step for user {user_id}: {step_name} = {step_data}")
return updated_step
@@ -300,26 +310,26 @@ class OnboardingService:
try:
# Get updated progress
user_progress_data = await self._get_user_onboarding_data(user_id)
# Calculate current status
completed_steps = []
for step_name in ONBOARDING_STEPS:
if user_progress_data.get(step_name, {}).get("completed", False):
completed_steps.append(step_name)
# Determine current and next step
current_step = self._get_current_step(completed_steps)
next_step = self._get_next_step(completed_steps)
# Calculate completion percentage
completion_percentage = (len(completed_steps) / len(ONBOARDING_STEPS)) * 100
# Check if fully completed
fully_completed = len(completed_steps) == len(ONBOARDING_STEPS)
# Format steps count
steps_completed_count = f"{len(completed_steps)}/{len(ONBOARDING_STEPS)}"
# Update summary in database
await self.onboarding_repo.upsert_user_summary(
user_id=user_id,
@@ -329,10 +339,18 @@ class OnboardingService:
fully_completed=fully_completed,
steps_completed_count=steps_completed_count
)
logger.debug(f"Successfully updated onboarding summary for user {user_id}")
except Exception as e:
logger.error(f"Error updating onboarding summary for user {user_id}: {e}")
# Don't raise here - summary update failure shouldn't break step updates
logger.error(f"Error updating onboarding summary for user {user_id}: {e}",
extra={"user_id": user_id, "error_type": type(e).__name__})
# Raise a warning-level HTTPException to inform frontend without breaking the flow
# This allows the step update to succeed while alerting about summary issues
raise HTTPException(
status_code=status.HTTP_207_MULTI_STATUS,
detail=f"Step updated successfully, but summary update failed: {str(e)}"
)
# API Routes

View File

@@ -61,11 +61,11 @@ class UserOnboardingSummary(Base):
# Summary fields
current_step = Column(String(50), nullable=False, default="user_registered")
next_step = Column(String(50))
completion_percentage = Column(String(10), default="0.0") # Store as string for precision
completion_percentage = Column(String(50), default="0.0") # Store as string for precision
fully_completed = Column(Boolean, default=False)
# Progress tracking
steps_completed_count = Column(String(10), default="0") # Store as string: "3/5"
steps_completed_count = Column(String(50), default="0") # Store as string: "3/5"
# Timestamps
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))

View File

@@ -82,11 +82,45 @@ async def training_progress_websocket(
tenant_id: str,
job_id: str
):
connection_id = f"{tenant_id}_{id(websocket)}"
# Validate token from query parameters
token = websocket.query_params.get("token")
if not token:
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
await websocket.close(code=1008, reason="Authentication token required")
return
# Validate the token (use the same JWT handler as gateway)
from shared.auth.jwt_handler import JWTHandler
from app.core.config import settings
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload:
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
await websocket.close(code=1008, reason="Invalid authentication token")
return
# Verify user has access to this tenant
user_id = payload.get('user_id')
if not user_id:
logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
await websocket.close(code=1008, reason="Invalid token payload")
return
logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")
except Exception as e:
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
await websocket.close(code=1008, reason="Token validation failed")
return
connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"
await connection_manager.connect(websocket, job_id, connection_id)
logger.info(f"WebSocket connection established for job {job_id}")
logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")
consumer_task = None
training_completed = False
@@ -100,11 +134,12 @@ async def training_progress_websocket(
while not training_completed:
try:
# FIXED: Use receive() instead of receive_text()
# Coordinate with frontend 30s heartbeat + gateway 45s timeout
# This should be longer than gateway timeout to avoid premature closure
try:
data = await asyncio.wait_for(websocket.receive(), timeout=30.0)
data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
last_activity = asyncio.get_event_loop().time()
# Handle different message types
if data["type"] == "websocket.receive":
if "text" in data:
@@ -123,31 +158,41 @@ async def training_progress_websocket(
elif message_text == "close":
logger.info(f"Client requested connection close for job {job_id}")
break
elif "bytes" in data:
# Handle binary messages (WebSocket ping frames)
# Handle binary messages (WebSocket ping frames) - respond with text pong for compatibility
await websocket.send_text("pong")
logger.debug(f"Binary ping received for job {job_id}")
logger.debug(f"Binary ping received for job {job_id}, responding with text pong")
elif data["type"] == "websocket.disconnect":
logger.info(f"WebSocket disconnect message received for job {job_id}")
break
except asyncio.TimeoutError:
# No message received in 30 seconds - send heartbeat
# No message received in 60 seconds - this is now coordinated with gateway timeouts
current_time = asyncio.get_event_loop().time()
if current_time - last_activity > 60:
logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat")
try:
await websocket.send_json({
"type": "heartbeat",
"job_id": job_id,
"timestamp": str(datetime.datetime.now())
})
except Exception as e:
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
break
# Send heartbeat only if we haven't received frontend ping for too long
# Frontend sends ping every 30s, so 60s timeout + 30s grace = 90s before heartbeat
if current_time - last_activity > 90: # 90 seconds of total inactivity
logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")
try:
await websocket.send_json({
"type": "heartbeat",
"job_id": job_id,
"timestamp": str(datetime.datetime.now()),
"message": "Training service heartbeat - frontend inactive",
"inactivity_seconds": int(current_time - last_activity)
})
last_activity = current_time
except Exception as e:
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
break
else:
# Normal timeout, frontend should be sending ping every 30s
logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
continue
except WebSocketDisconnect:
logger.info(f"WebSocket client disconnected for job {job_id}")

View File

@@ -43,11 +43,71 @@ async def get_db_health() -> bool:
await conn.execute(text("SELECT 1"))
logger.debug("Database health check passed")
return True
except Exception as e:
logger.error("Database health check failed", error=str(e))
return False
async def get_comprehensive_db_health() -> dict:
    """
    Comprehensive health check that verifies both connectivity and table existence.

    Returns:
        A dict with the keys:
          - "status": "healthy" or "unhealthy"
          - "connectivity": result of ``get_db_health()``
          - "tables_exist": result of ``_verify_tables_exist()``
          - "tables_verified": names of tables that could be queried
          - "missing_tables": names of tables whose probe query failed
          - "errors": human-readable error strings

    This function does not raise: any unexpected exception is recorded in
    "errors" and the status is reported as "unhealthy".
    """
    # Function-scope import so the block is self-contained regardless of
    # whether ``text`` is also imported at module level.
    from sqlalchemy import text

    # Single source of truth for the tables this check cares about.
    required_tables = (
        'model_training_logs', 'trained_models', 'model_performance_metrics',
        'training_job_queue', 'model_artifacts',
    )

    health_status = {
        "status": "healthy",
        "connectivity": False,
        "tables_exist": False,
        "tables_verified": [],
        "missing_tables": [],
        "errors": []
    }

    try:
        # Test basic connectivity; without it the table checks are meaningless.
        health_status["connectivity"] = await get_db_health()
        if not health_status["connectivity"]:
            health_status["status"] = "unhealthy"
            health_status["errors"].append("Database connectivity failed")
            return health_status

        # Test table existence
        tables_verified = await _verify_tables_exist()
        health_status["tables_exist"] = tables_verified

        if tables_verified:
            health_status["tables_verified"] = list(required_tables)
        else:
            health_status["status"] = "unhealthy"
            health_status["errors"].append("Required tables missing or inaccessible")

            # Try to identify which specific tables are missing
            try:
                async with database_manager.get_session() as session:
                    for table_name in required_tables:
                        try:
                            # Probe inside a SAVEPOINT: on PostgreSQL a failed
                            # statement aborts the enclosing transaction, which
                            # would make every later probe fail and be reported
                            # as missing. Rolling back only the savepoint keeps
                            # the session usable for the remaining tables.
                            # NOTE(review): assumes get_session() yields a
                            # SQLAlchemy AsyncSession - confirm.
                            async with session.begin_nested():
                                await session.execute(text(f"SELECT 1 FROM {table_name} LIMIT 1"))
                            health_status["tables_verified"].append(table_name)
                        except Exception:
                            health_status["missing_tables"].append(table_name)
            except Exception as e:
                health_status["errors"].append(f"Error checking individual tables: {str(e)}")

        logger.debug("Comprehensive database health check completed",
                     status=health_status["status"],
                     connectivity=health_status["connectivity"],
                     tables_exist=health_status["tables_exist"])

    except Exception as e:
        # Top-level boundary: a health endpoint must report, not crash.
        health_status["status"] = "unhealthy"
        health_status["errors"].append(f"Health check failed: {str(e)}")
        logger.error("Comprehensive database health check failed", error=str(e))

    return health_status
# Training service specific database utilities
class TrainingDatabaseUtils:
"""Training service specific database utilities"""
@@ -223,27 +283,118 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
# Database initialization for training service
async def initialize_training_database(*, max_retries: int = 5, retry_delay: float = 2.0):
    """Initialize database tables for training service with retry logic and verification.

    Each attempt runs four steps: test connectivity, import the ORM models and
    check they are registered on ``Base.metadata``, create the tables, and
    verify the tables can be queried. A failed attempt is retried after an
    exponentially growing delay (``retry_delay``, then 2x, 4x, ...).

    Args:
        max_retries: Total number of attempts before giving up (must be >= 1).
        retry_delay: Base delay in seconds before the first retry.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1.
        RuntimeError: If every attempt failed; the last underlying error is
            chained as ``__cause__``.
    """
    import asyncio

    if max_retries < 1:
        # Without this guard the loop below would not run at all and the
        # function would return as if initialization had succeeded.
        raise ValueError("max_retries must be at least 1")

    expected_tables = {
        'model_training_logs', 'trained_models', 'model_performance_metrics',
        'training_job_queue', 'model_artifacts'
    }

    for attempt in range(1, max_retries + 1):
        try:
            logger.info("Initializing training service database",
                        attempt=attempt, max_retries=max_retries)

            # Step 1: Test database connectivity first
            logger.info("Testing database connectivity...")
            connection_ok = await database_manager.test_connection()
            if not connection_ok:
                raise RuntimeError("Database connection test failed")
            logger.info("Database connectivity verified")

            # Step 2: Import models to ensure they're registered.
            # The names are unused on purpose: importing the module is what
            # registers the tables on Base.metadata.
            logger.info("Importing and registering database models...")
            from app.models.training import (  # noqa: F401
                ModelTrainingLog,
                TrainedModel,
                ModelPerformanceMetric,
                TrainingJobQueue,
                ModelArtifact
            )

            # Verify models are registered in metadata
            registered_tables = set(Base.metadata.tables)
            missing_tables = expected_tables - registered_tables
            if missing_tables:
                raise RuntimeError(f"Models not properly registered: {missing_tables}")

            logger.info("Models registered successfully",
                        tables=list(registered_tables))

            # Step 3: Create tables using shared infrastructure with verification
            logger.info("Creating database tables...")
            await database_manager.create_tables()

            # Step 4: Verify tables were actually created
            logger.info("Verifying table creation...")
            verification_successful = await _verify_tables_exist()
            if not verification_successful:
                raise RuntimeError("Table verification failed - tables were not created properly")

            logger.info("Training service database initialized and verified successfully",
                        attempt=attempt)
            return

        except Exception as e:
            logger.error("Database initialization failed",
                         attempt=attempt,
                         max_retries=max_retries,
                         error=str(e))

            if attempt == max_retries:
                logger.error("All database initialization attempts failed - giving up")
                raise RuntimeError(
                    f"Failed to initialize training database after {max_retries} attempts: {str(e)}"
                ) from e

            # Wait before retry with exponential backoff
            wait_time = retry_delay * (2 ** (attempt - 1))
            logger.info("Retrying database initialization",
                        retry_in_seconds=wait_time,
                        next_attempt=attempt + 1)
            await asyncio.sleep(wait_time)
async def _verify_tables_exist() -> bool:
    """Verify that all required tables exist in the database.

    Runs ``SELECT 1 FROM <table> LIMIT 1`` against every required table.

    Returns:
        True when no probe reported a "does not exist" error; False when a
        table is missing or the verification itself could not be carried out.
        This function does not raise - callers rely on the boolean result.
    """
    # Function-scope import so the block is self-contained regardless of
    # whether ``text`` is also imported at module level.
    from sqlalchemy import text

    required_tables = [
        'model_training_logs',
        'trained_models',
        'model_performance_metrics',
        'training_job_queue',
        'model_artifacts'
    ]

    try:
        async with database_manager.get_session() as session:
            # Check each required table exists and is accessible
            for table_name in required_tables:
                try:
                    # Probe inside a SAVEPOINT so that one failed statement
                    # does not abort the transaction and turn every later
                    # probe into a misleading "transaction aborted" error.
                    # NOTE(review): assumes get_session() yields a SQLAlchemy
                    # AsyncSession - confirm.
                    async with session.begin_nested():
                        await session.execute(
                            text(f"SELECT 1 FROM {table_name} LIMIT 1")
                        )
                    logger.debug(f"Table {table_name} exists and is accessible")
                except Exception as table_error:
                    error_text = str(table_error).lower()
                    # If it's a "relation does not exist" error, table creation failed
                    if "does not exist" in error_text:
                        logger.error(f"Table {table_name} does not exist", error=str(table_error))
                        return False
                    # If it's an empty table, that's fine - table exists
                    elif "no data" in error_text:
                        logger.debug(f"Table {table_name} exists but is empty (normal)")
                    else:
                        # Unknown failures are logged but not treated as a
                        # missing table (kept lenient, as before).
                        logger.warning(f"Unexpected error querying {table_name}", error=str(table_error))

        logger.info("All required tables verified successfully",
                    tables=required_tables)
        return True

    except Exception as e:
        logger.error("Table verification failed", error=str(e))
        return False
# Database cleanup for training service
async def cleanup_training_database():

View File

@@ -16,7 +16,7 @@ from fastapi.responses import JSONResponse
import uvicorn
from app.core.config import settings
from app.core.database import initialize_training_database, cleanup_training_database, get_db_health
from app.core.database import initialize_training_database, cleanup_training_database, get_db_health, get_comprehensive_db_health
from app.api import training, models
from app.api.websocket import websocket_router
@@ -195,18 +195,69 @@ async def health_check():
@app.get("/health/ready")
async def readiness_check():
    """Kubernetes readiness probe endpoint with comprehensive database checks.

    Returns:
        A plain dict (HTTP 200) with "status": "ready" when database
        connectivity, table verification and the application ready flag all
        pass; otherwise a 503 ``JSONResponse`` with "status": "not ready".
        Unexpected errors are also reported as 503 rather than raised.
    """
    try:
        # Get comprehensive database health including table verification
        db_health = await get_comprehensive_db_health()

        checks = {
            "database_connectivity": db_health["connectivity"],
            "database_tables": db_health["tables_exist"],
            # Missing flag counts as "not ready".
            "application": getattr(app.state, 'ready', False)
        }

        # Include detailed database info for debugging
        database_details = {
            "status": db_health["status"],
            "tables_verified": db_health["tables_verified"],
            "missing_tables": db_health["missing_tables"],
            "errors": db_health["errors"]
        }

        # Service is ready only if all checks pass
        all_ready = all(checks.values()) and db_health["status"] == "healthy"

        if all_ready:
            return {
                "status": "ready",
                "checks": checks,
                "database": database_details
            }

        return JSONResponse(
            status_code=503,
            content={
                "status": "not ready",
                "checks": checks,
                "database": database_details
            }
        )

    except Exception as e:
        # Probe boundary: report the failure as "not ready" instead of a 500.
        logger.error("Readiness check failed", error=str(e))
        return JSONResponse(
            status_code=503,
            content={
                "status": "not ready",
                "error": f"Health check failed: {str(e)}"
            }
        )
@app.get("/health/database")
async def database_health_check():
    """Detailed database health endpoint for debugging.

    Responds with the full report from ``get_comprehensive_db_health()``:
    HTTP 200 when its "status" is "healthy", HTTP 503 otherwise. Any
    exception is logged and turned into a 503 "unhealthy" payload.
    """
    try:
        report = await get_comprehensive_db_health()
        if report["status"] == "healthy":
            http_status = 200
        else:
            http_status = 503
        return JSONResponse(status_code=http_status, content=report)
    except Exception as e:
        logger.error("Database health check failed", error=str(e))
        failure_body = {
            "status": "unhealthy",
            "error": f"Health check failed: {str(e)}"
        }
        return JSONResponse(status_code=503, content=failure_body)
@app.get("/metrics")
@@ -220,11 +271,6 @@ async def get_metrics():
async def liveness_check():
    """Report that the service process is alive.

    Returns a constant payload and performs no dependency checks.
    """
    alive_payload = {"status": "alive"}
    return alive_payload
# Minimal readiness handler: reads app.state.ready and treats a missing
# attribute as ready (default True), returning only a status string.
# NOTE(review): this registers the same "/health/ready" path and function name
# as the readiness_check defined earlier in this file; it looks like the
# removed side of this diff hunk (old 11 lines -> new 6) - confirm it is gone
# from the resulting file.
@app.get("/health/ready")
async def readiness_check():
ready = getattr(app.state, 'ready', True)
return {"status": "ready" if ready else "not ready"}
@app.get("/")
async def root():
    """Return the service name and version served at the root path."""
    service_info = dict(service="training-service", version="1.0.0")
    return service_info