Add base kubernetes support final fix 4
This commit is contained in:
@@ -82,11 +82,45 @@ async def training_progress_websocket(
|
||||
tenant_id: str,
|
||||
job_id: str
|
||||
):
|
||||
connection_id = f"{tenant_id}_{id(websocket)}"
|
||||
|
||||
# Validate token from query parameters
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
|
||||
await websocket.close(code=1008, reason="Authentication token required")
|
||||
return
|
||||
|
||||
# Validate the token (use the same JWT handler as gateway)
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from app.core.config import settings
|
||||
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if not payload:
|
||||
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
|
||||
await websocket.close(code=1008, reason="Invalid authentication token")
|
||||
return
|
||||
|
||||
# Verify user has access to this tenant
|
||||
user_id = payload.get('user_id')
|
||||
if not user_id:
|
||||
logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
|
||||
await websocket.close(code=1008, reason="Invalid token payload")
|
||||
return
|
||||
|
||||
logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
|
||||
await websocket.close(code=1008, reason="Token validation failed")
|
||||
return
|
||||
|
||||
connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"
|
||||
|
||||
await connection_manager.connect(websocket, job_id, connection_id)
|
||||
logger.info(f"WebSocket connection established for job {job_id}")
|
||||
|
||||
logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")
|
||||
|
||||
consumer_task = None
|
||||
training_completed = False
|
||||
|
||||
@@ -100,11 +134,12 @@ async def training_progress_websocket(
|
||||
|
||||
while not training_completed:
|
||||
try:
|
||||
# FIXED: Use receive() instead of receive_text()
|
||||
# Coordinate with frontend 30s heartbeat + gateway 45s timeout
|
||||
# This should be longer than gateway timeout to avoid premature closure
|
||||
try:
|
||||
data = await asyncio.wait_for(websocket.receive(), timeout=30.0)
|
||||
data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
|
||||
last_activity = asyncio.get_event_loop().time()
|
||||
|
||||
|
||||
# Handle different message types
|
||||
if data["type"] == "websocket.receive":
|
||||
if "text" in data:
|
||||
@@ -123,31 +158,41 @@ async def training_progress_websocket(
|
||||
elif message_text == "close":
|
||||
logger.info(f"Client requested connection close for job {job_id}")
|
||||
break
|
||||
|
||||
|
||||
elif "bytes" in data:
|
||||
# Handle binary messages (WebSocket ping frames)
|
||||
# Handle binary messages (WebSocket ping frames) - respond with text pong for compatibility
|
||||
await websocket.send_text("pong")
|
||||
logger.debug(f"Binary ping received for job {job_id}")
|
||||
|
||||
logger.debug(f"Binary ping received for job {job_id}, responding with text pong")
|
||||
|
||||
elif data["type"] == "websocket.disconnect":
|
||||
logger.info(f"WebSocket disconnect message received for job {job_id}")
|
||||
break
|
||||
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# No message received in 30 seconds - send heartbeat
|
||||
# No message received in 60 seconds - this is now coordinated with gateway timeouts
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
if current_time - last_activity > 60:
|
||||
logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat")
|
||||
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "heartbeat",
|
||||
"job_id": job_id,
|
||||
"timestamp": str(datetime.datetime.now())
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
|
||||
break
|
||||
|
||||
# Send heartbeat only if we haven't received frontend ping for too long
|
||||
# Frontend sends ping every 30s, so 60s timeout + 30s grace = 90s before heartbeat
|
||||
if current_time - last_activity > 90: # 90 seconds of total inactivity
|
||||
logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")
|
||||
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "heartbeat",
|
||||
"job_id": job_id,
|
||||
"timestamp": str(datetime.datetime.now()),
|
||||
"message": "Training service heartbeat - frontend inactive",
|
||||
"inactivity_seconds": int(current_time - last_activity)
|
||||
})
|
||||
last_activity = current_time
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
|
||||
break
|
||||
else:
|
||||
# Normal timeout, frontend should be sending ping every 30s
|
||||
logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
|
||||
continue
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket client disconnected for job {job_id}")
|
||||
|
||||
@@ -43,11 +43,71 @@ async def get_db_health() -> bool:
|
||||
await conn.execute(text("SELECT 1"))
|
||||
logger.debug("Database health check passed")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Database health check failed", error=str(e))
|
||||
return False
|
||||
|
||||
async def get_comprehensive_db_health() -> dict:
    """
    Comprehensive health check that verifies both connectivity and table existence.

    Returns:
        A dict with the keys "status" ("healthy" or "unhealthy"),
        "connectivity", "tables_exist", "tables_verified", "missing_tables"
        and "errors". Exceptions raised by the individual checks are caught
        and reported through "status" and "errors" instead of propagating.
    """
    # Single source of truth for the tables this service requires.
    required_tables = (
        'model_training_logs', 'trained_models', 'model_performance_metrics',
        'training_job_queue', 'model_artifacts',
    )

    health_status = {
        "status": "healthy",
        "connectivity": False,
        "tables_exist": False,
        "tables_verified": [],
        "missing_tables": [],
        "errors": []
    }

    try:
        # Test basic connectivity
        health_status["connectivity"] = await get_db_health()

        if not health_status["connectivity"]:
            # Without a connection there is no point probing tables.
            health_status["status"] = "unhealthy"
            health_status["errors"].append("Database connectivity failed")
            return health_status

        # Test table existence
        tables_verified = await _verify_tables_exist()
        health_status["tables_exist"] = tables_verified

        if tables_verified:
            health_status["tables_verified"] = list(required_tables)
        else:
            health_status["status"] = "unhealthy"
            health_status["errors"].append("Required tables missing or inaccessible")

            # Try to identify which specific tables are missing
            try:
                async with database_manager.get_session() as session:
                    for table_name in required_tables:
                        try:
                            await session.execute(text(f"SELECT 1 FROM {table_name} LIMIT 1"))
                            health_status["tables_verified"].append(table_name)
                        except Exception:
                            health_status["missing_tables"].append(table_name)
                            # A failed statement can leave the surrounding
                            # transaction aborted (PostgreSQL), which would make
                            # every later probe fail too; reset it first.
                            # NOTE(review): assumes get_session() yields a
                            # SQLAlchemy AsyncSession - confirm.
                            await session.rollback()
            except Exception as e:
                health_status["errors"].append(f"Error checking individual tables: {str(e)}")

        logger.debug("Comprehensive database health check completed",
                     status=health_status["status"],
                     connectivity=health_status["connectivity"],
                     tables_exist=health_status["tables_exist"])

    except Exception as e:
        health_status["status"] = "unhealthy"
        health_status["errors"].append(f"Health check failed: {str(e)}")
        logger.error("Comprehensive database health check failed", error=str(e))

    return health_status
|
||||
|
||||
# Training service specific database utilities
|
||||
class TrainingDatabaseUtils:
|
||||
"""Training service specific database utilities"""
|
||||
@@ -223,27 +283,118 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
|
||||
# Database initialization for training service
|
||||
async def initialize_training_database():
    """Initialize database tables for training service with retry logic and verification

    Each attempt tests connectivity, imports the models, checks that the
    expected tables are registered in ``Base.metadata``, creates the tables
    and verifies them with ``_verify_tables_exist``. Failed attempts are
    retried with exponential backoff (2s, 4s, 8s, 16s).

    Raises:
        Exception: when all ``max_retries`` attempts fail; the last
            underlying error is attached as ``__cause__``.
    """
    import asyncio

    max_retries = 5
    retry_delay = 2.0

    for attempt in range(1, max_retries + 1):
        try:
            logger.info("Initializing training service database",
                        attempt=attempt, max_retries=max_retries)

            # Step 1: Test database connectivity first
            logger.info("Testing database connectivity...")
            connection_ok = await database_manager.test_connection()
            if not connection_ok:
                raise Exception("Database connection test failed")
            logger.info("Database connectivity verified")

            # Step 2: Import models to ensure they're registered
            # (the names are unused; the import is done for its side effect)
            logger.info("Importing and registering database models...")
            from app.models.training import (
                ModelTrainingLog,
                TrainedModel,
                ModelPerformanceMetric,
                TrainingJobQueue,
                ModelArtifact
            )

            # Verify models are registered in metadata
            expected_tables = {
                'model_training_logs', 'trained_models', 'model_performance_metrics',
                'training_job_queue', 'model_artifacts'
            }
            registered_tables = set(Base.metadata.tables.keys())
            missing_tables = expected_tables - registered_tables
            if missing_tables:
                raise Exception(f"Models not properly registered: {missing_tables}")

            logger.info("Models registered successfully",
                        tables=list(registered_tables))

            # Step 3: Create tables using shared infrastructure with verification
            logger.info("Creating database tables...")
            await database_manager.create_tables()

            # Step 4: Verify tables were actually created
            logger.info("Verifying table creation...")
            verification_successful = await _verify_tables_exist()

            if not verification_successful:
                raise Exception("Table verification failed - tables were not created properly")

            logger.info("Training service database initialized and verified successfully",
                        attempt=attempt)
            return

        except Exception as e:
            logger.error("Database initialization failed",
                         attempt=attempt,
                         max_retries=max_retries,
                         error=str(e))

            if attempt == max_retries:
                logger.error("All database initialization attempts failed - giving up")
                # Chain the original error so its traceback is not lost.
                raise Exception(
                    f"Failed to initialize training database after {max_retries} attempts: {str(e)}"
                ) from e

            # Wait before retry with exponential backoff
            wait_time = retry_delay * (2 ** (attempt - 1))
            logger.info("Retrying database initialization",
                        retry_in_seconds=wait_time,
                        next_attempt=attempt + 1)
            await asyncio.sleep(wait_time)
|
||||
|
||||
async def _verify_tables_exist() -> bool:
    """Verify that all required tables exist in the database

    Runs ``SELECT 1 FROM <table> LIMIT 1`` against each required table.

    Returns:
        False as soon as a probe fails with a "does not exist" error, or when
        the verification itself raises; True otherwise. Other probe errors
        are logged as warnings and do not fail the verification.
    """
    try:
        async with database_manager.get_session() as session:
            # Check each required table exists and is accessible
            required_tables = [
                'model_training_logs',
                'trained_models',
                'model_performance_metrics',
                'training_job_queue',
                'model_artifacts'
            ]

            for table_name in required_tables:
                try:
                    # Try to query the table structure
                    await session.execute(
                        text(f"SELECT 1 FROM {table_name} LIMIT 1")
                    )
                    logger.debug(f"Table {table_name} exists and is accessible")
                except Exception as table_error:
                    error_text = str(table_error)
                    lowered = error_text.lower()

                    # If it's a "relation does not exist" error, table creation failed
                    if "does not exist" in lowered:
                        logger.error(f"Table {table_name} does not exist", error=error_text)
                        return False
                    # If it's an empty table, that's fine - table exists
                    elif "no data" in lowered:
                        logger.debug(f"Table {table_name} exists but is empty (normal)")
                    else:
                        logger.warning(f"Unexpected error querying {table_name}", error=error_text)

                    # A failed statement can leave the transaction aborted
                    # (PostgreSQL); without a reset, later probes would fail
                    # with an unrelated error and a missing table could slip
                    # through undetected.
                    # NOTE(review): assumes get_session() yields a SQLAlchemy
                    # AsyncSession - confirm.
                    await session.rollback()

            logger.info("All required tables verified successfully",
                        tables=required_tables)
            return True

    except Exception as e:
        logger.error("Table verification failed", error=str(e))
        return False
|
||||
|
||||
# Database cleanup for training service
|
||||
async def cleanup_training_database():
|
||||
|
||||
@@ -16,7 +16,7 @@ from fastapi.responses import JSONResponse
|
||||
import uvicorn
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import initialize_training_database, cleanup_training_database, get_db_health
|
||||
from app.core.database import initialize_training_database, cleanup_training_database, get_db_health, get_comprehensive_db_health
|
||||
from app.api import training, models
|
||||
|
||||
from app.api.websocket import websocket_router
|
||||
@@ -195,18 +195,69 @@ async def health_check():
|
||||
|
||||
@app.get("/health/ready")
async def readiness_check():
    """Kubernetes readiness probe endpoint with comprehensive database checks

    Returns the readiness payload directly when every check passes and the
    database status is "healthy"; otherwise returns the same payload shape
    with HTTP 503. Any exception while building the payload also yields 503.
    """
    try:
        # Comprehensive database health, including table verification
        report = await get_comprehensive_db_health()

        checks = {
            "database_connectivity": report["connectivity"],
            "database_tables": report["tables_exist"],
            "application": getattr(app.state, 'ready', False)
        }

        # Detailed database info, included for debugging
        database_details = {
            field: report[field]
            for field in ("status", "tables_verified", "missing_tables", "errors")
        }

        body = {
            "status": "ready",
            "checks": checks,
            "database": database_details
        }

        # Ready only if every check passes and the database reports healthy
        if all(checks.values()) and report["status"] == "healthy":
            return body

        body["status"] = "not ready"
        return JSONResponse(status_code=503, content=body)

    except Exception as e:
        logger.error("Readiness check failed", error=str(e))
        failure = {
            "status": "not ready",
            "error": f"Health check failed: {str(e)}"
        }
        return JSONResponse(status_code=503, content=failure)
|
||||
|
||||
@app.get("/health/database")
async def database_health_check():
    """Detailed database health endpoint for debugging

    Responds with the comprehensive health report: HTTP 200 when its status
    is "healthy", HTTP 503 otherwise or when the check itself raises.
    """
    try:
        report = await get_comprehensive_db_health()
        if report["status"] == "healthy":
            http_status = 200
        else:
            http_status = 503
        return JSONResponse(status_code=http_status, content=report)
    except Exception as e:
        logger.error("Database health check failed", error=str(e))
        failure = {
            "status": "unhealthy",
            "error": f"Health check failed: {str(e)}"
        }
        return JSONResponse(status_code=503, content=failure)
|
||||
|
||||
@app.get("/metrics")
|
||||
@@ -220,11 +271,6 @@ async def get_metrics():
|
||||
async def liveness_check():
    """Return a static payload reporting that the service process is alive."""
    return {"status": "alive"}
|
||||
|
||||
@app.get("/health/ready")
|
||||
async def readiness_check():
|
||||
ready = getattr(app.state, 'ready', True)
|
||||
return {"status": "ready" if ready else "not ready"}
|
||||
|
||||
@app.get("/")
async def root():
    """Return the service name and version."""
    return {"service": "training-service", "version": "1.0.0"}
|
||||
|
||||
Reference in New Issue
Block a user