Add base kubernetes support final fix 4
@@ -91,7 +91,27 @@ class ApiClient {
 
     // Response interceptor for error handling and automatic token refresh
     this.client.interceptors.response.use(
-      (response) => response,
+      (response) => {
+        // Enhanced logging for token refresh header detection
+        const refreshSuggested = response.headers['x-token-refresh-suggested'];
+        if (refreshSuggested) {
+          console.log('🔍 TOKEN REFRESH HEADER DETECTED:', {
+            url: response.config?.url,
+            method: response.config?.method,
+            status: response.status,
+            refreshSuggested,
+            hasRefreshToken: !!this.refreshToken,
+            currentTokenLength: this.authToken?.length || 0
+          });
+        }
+
+        // Check if server suggests token refresh
+        if (refreshSuggested === 'true' && this.refreshToken) {
+          console.log('🔄 Server suggests token refresh - refreshing proactively');
+          this.proactiveTokenRefresh();
+        }
+        return response;
+      },
       async (error) => {
         const originalRequest = error.config;
 

@@ -228,6 +248,40 @@ class ApiClient {
       }
     }
 
+  private async proactiveTokenRefresh() {
+    // Avoid multiple simultaneous proactive refreshes
+    if (this.isRefreshing) {
+      return;
+    }
+
+    try {
+      this.isRefreshing = true;
+      console.log('🔄 Proactively refreshing token...');
+
+      const response = await this.client.post('/auth/refresh', {
+        refresh_token: this.refreshToken
+      });
+
+      const { access_token, refresh_token } = response.data;
+
+      // Update tokens
+      this.setAuthToken(access_token);
+      if (refresh_token) {
+        this.setRefreshToken(refresh_token);
+      }
+
+      // Update auth store
+      await this.updateAuthStore(access_token, refresh_token);
+
+      console.log('✅ Proactive token refresh successful');
+    } catch (error) {
+      console.warn('⚠️ Proactive token refresh failed:', error);
+      // Don't handle as auth failure here - let the next 401 handle it
+    } finally {
+      this.isRefreshing = false;
+    }
+  }
+
   private async handleAuthFailure() {
     try {
       // Clear tokens
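The boolean `isRefreshing` guard above simply drops any refresh requested while one is already in flight. A common alternative, shown here as a minimal sketch that is not part of this commit, is to cache the in-flight promise so concurrent callers all await the same request; `TokenRefresher`, `refreshInFlight`, and `getRefreshToken` are illustrative names, not APIs from this codebase.

```ts
import { AxiosInstance } from 'axios';

// Hypothetical illustration: single-flight token refresh.
class TokenRefresher {
  private refreshInFlight: Promise<string> | null = null;

  constructor(
    private client: AxiosInstance,
    private getRefreshToken: () => string | null
  ) {}

  // Every concurrent caller awaits the same in-flight request instead of racing.
  refresh(): Promise<string> {
    if (!this.refreshInFlight) {
      this.refreshInFlight = this.client
        .post('/auth/refresh', { refresh_token: this.getRefreshToken() })
        .then((res) => res.data.access_token as string)
        .finally(() => {
          this.refreshInFlight = null; // allow the next refresh cycle
        });
    }
    return this.refreshInFlight;
  }
}
```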
@@ -275,6 +329,117 @@ class ApiClient {
     return this.tenantId;
   }
 
+  // Token synchronization methods for WebSocket connections
+  getCurrentValidToken(): string | null {
+    return this.authToken;
+  }
+
+  async ensureValidToken(): Promise<string | null> {
+    const originalToken = this.authToken;
+    const originalTokenShort = originalToken ? `${originalToken.slice(0, 20)}...${originalToken.slice(-10)}` : 'null';
+
+    console.log('🔍 ensureValidToken() called:', {
+      hasToken: !!this.authToken,
+      tokenPreview: originalTokenShort,
+      isRefreshing: this.isRefreshing,
+      hasRefreshToken: !!this.refreshToken
+    });
+
+    // If we have a valid token, return it
+    if (this.authToken && !this.isTokenNearExpiry(this.authToken)) {
+      const expiryInfo = this.getTokenExpiryInfo(this.authToken);
+      console.log('✅ Token is valid, returning current token:', {
+        tokenPreview: originalTokenShort,
+        expiryInfo
+      });
+      return this.authToken;
+    }
+
+    // If token is near expiry or expired, try to refresh
+    if (this.refreshToken && !this.isRefreshing) {
+      console.log('🔄 Token needs refresh, attempting proactive refresh:', {
+        reason: this.authToken ? 'near expiry' : 'no token',
+        expiryInfo: this.authToken ? this.getTokenExpiryInfo(this.authToken) : 'N/A'
+      });
+
+      try {
+        await this.proactiveTokenRefresh();
+        const newTokenShort = this.authToken ? `${this.authToken.slice(0, 20)}...${this.authToken.slice(-10)}` : 'null';
+        const tokenChanged = originalToken !== this.authToken;
+
+        console.log('✅ Token refresh completed:', {
+          tokenChanged,
+          oldTokenPreview: originalTokenShort,
+          newTokenPreview: newTokenShort,
+          newExpiryInfo: this.authToken ? this.getTokenExpiryInfo(this.authToken) : 'N/A'
+        });
+
+        return this.authToken;
+      } catch (error) {
+        console.warn('❌ Failed to refresh token in ensureValidToken:', error);
+        return null;
+      }
+    }
+
+    console.log('⚠️ Returning current token without refresh:', {
+      reason: this.isRefreshing ? 'already refreshing' : 'no refresh token',
+      tokenPreview: originalTokenShort
+    });
+    return this.authToken;
+  }
+
+  private getTokenExpiryInfo(token: string): any {
+    try {
+      const payload = JSON.parse(atob(token.split('.')[1]));
+      const exp = payload.exp;
+      const iat = payload.iat;
+      if (!exp) return { error: 'No expiry in token' };
+
+      const now = Math.floor(Date.now() / 1000);
+      const timeUntilExpiry = exp - now;
+      const tokenLifetime = exp - iat;
+
+      return {
+        issuedAt: new Date(iat * 1000).toISOString(),
+        expiresAt: new Date(exp * 1000).toISOString(),
+        lifetimeMinutes: Math.floor(tokenLifetime / 60),
+        secondsUntilExpiry: timeUntilExpiry,
+        minutesUntilExpiry: Math.floor(timeUntilExpiry / 60),
+        isNearExpiry: timeUntilExpiry < 300,
+        isExpired: timeUntilExpiry <= 0
+      };
+    } catch (error) {
+      return { error: 'Failed to parse token', details: error };
+    }
+  }
+
+  private isTokenNearExpiry(token: string): boolean {
+    try {
+      const payload = JSON.parse(atob(token.split('.')[1]));
+      const exp = payload.exp;
+      if (!exp) return false;
+
+      const now = Math.floor(Date.now() / 1000);
+      const timeUntilExpiry = exp - now;
+
+      // Consider token near expiry if less than 5 minutes remaining
+      const isNear = timeUntilExpiry < 300;
+
+      if (isNear) {
+        console.log('⏰ Token is near expiry:', {
+          secondsUntilExpiry: timeUntilExpiry,
+          minutesUntilExpiry: Math.floor(timeUntilExpiry / 60),
+          expiresAt: new Date(exp * 1000).toISOString()
+        });
+      }
+
+      return isNear;
+    } catch (error) {
+      console.warn('Failed to parse token for expiry check:', error);
+      return true; // Assume expired if we can't parse
+    }
+  }
+
   // HTTP Methods - Return direct data for React Query
   async get<T = any>(url: string, config?: AxiosRequestConfig): Promise<T> {
     const response: AxiosResponse<T> = await this.client.get(url, config);
 
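One caveat about the parsing above: `atob(token.split('.')[1])` treats the payload as plain base64, but JWT segments are base64url-encoded (RFC 7515) and unpadded, so tokens whose payload contains `-` or `_` can make `atob` throw and fall into the "assume expired" branch. A minimal, hypothetical decoder that normalizes the alphabet and padding first could look like this; `decodeJwtPayload` is an illustrative name, not an API from this codebase.

```ts
// Hypothetical helper, not part of the commit. Still throws on malformed
// tokens, so callers should keep their try/catch.
function decodeJwtPayload(token: string): Record<string, unknown> {
  const segment = token.split('.')[1] ?? '';
  // Map the base64url alphabet back to base64 and restore '=' padding.
  const base64 = segment.replace(/-/g, '+').replace(/_/g, '/');
  const padded = base64.padEnd(base64.length + ((4 - (base64.length % 4)) % 4), '=');
  return JSON.parse(atob(padded));
}
```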
@@ -105,9 +105,20 @@ export const useMarkStepCompleted = (
     mutationFn: ({ userId, stepName, data }) =>
       onboardingService.markStepCompleted(userId, stepName, data),
     onSuccess: (data, { userId }) => {
-      // Update progress cache
+      // Update progress cache with new data
       queryClient.setQueryData(onboardingKeys.progress(userId), data);
+
+      // Invalidate the query to ensure fresh data on next access
+      queryClient.invalidateQueries({ queryKey: onboardingKeys.progress(userId) });
     },
+    onError: (error, { userId, stepName }) => {
+      console.error(`Failed to complete step ${stepName} for user ${userId}:`, error);
+
+      // Invalidate queries on error to ensure we get fresh data
+      queryClient.invalidateQueries({ queryKey: onboardingKeys.progress(userId) });
+    },
+    // Prevent duplicate requests by using the step name as a mutation key
+    mutationKey: (variables) => ['markStepCompleted', variables?.userId, variables?.stepName],
     ...options,
   });
 };
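A note on the `mutationKey` added above: in TanStack Query (v4/v5) `mutationKey` is declared as a static array, not a function of the mutation variables, so the function form would not take part in key matching. A sketch of the array form follows; the key segments are carried over from the diff, while the hook shape and `useMarkStepCompletedSketch` name are assumptions for illustration.

```ts
import { useMutation } from '@tanstack/react-query';

// Sketch only: mutationKey must be a static array, so the identifiers are
// baked in when the hook is created rather than derived from mutate() variables.
export const useMarkStepCompletedSketch = (userId: string, stepName: string) =>
  useMutation({
    mutationKey: ['markStepCompleted', userId, stepName],
    // Stand-in for onboardingService.markStepCompleted(userId, stepName, data).
    mutationFn: (data: unknown) => Promise.resolve(data),
  });
```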
@@ -3,10 +3,10 @@
  * Provides data fetching, caching, and state management for training operations
  */
 
-import React from 'react';
+import * as React from 'react';
 import { useMutation, useQuery, useQueryClient, UseQueryOptions, UseMutationOptions } from '@tanstack/react-query';
 import { trainingService } from '../services/training';
-import { ApiError } from '../client/apiClient';
+import { ApiError, apiClient } from '../client/apiClient';
 import { useAuthStore } from '../../stores/auth.store';
 import type {
   TrainingJobRequest,

@@ -53,15 +53,62 @@ export const trainingKeys = {
 export const useTrainingJobStatus = (
   tenantId: string,
   jobId: string,
-  options?: Omit<UseQueryOptions<TrainingJobStatus, ApiError>, 'queryKey' | 'queryFn'>
+  options?: Omit<UseQueryOptions<TrainingJobStatus, ApiError>, 'queryKey' | 'queryFn'> & {
+    isWebSocketConnected?: boolean;
+  }
 ) => {
+  const { isWebSocketConnected, ...queryOptions } = options || {};
+
+  // Completely disable the query when WebSocket is connected
+  const isEnabled = !!tenantId && !!jobId && !isWebSocketConnected;
+
+  console.log('🔄 Training status query:', {
+    tenantId: !!tenantId,
+    jobId: !!jobId,
+    isWebSocketConnected,
+    queryEnabled: isEnabled
+  });
+
   return useQuery<TrainingJobStatus, ApiError>({
     queryKey: trainingKeys.jobs.status(tenantId, jobId),
-    queryFn: () => trainingService.getTrainingJobStatus(tenantId, jobId),
-    enabled: !!tenantId && !!jobId,
-    refetchInterval: 5000, // Poll every 5 seconds while training
+    queryFn: () => {
+      console.log('📡 Executing HTTP training status query (WebSocket disconnected)');
+      return trainingService.getTrainingJobStatus(tenantId, jobId);
+    },
+    enabled: isEnabled, // Completely disable when WebSocket connected
+    refetchInterval: (query) => {
+      // CRITICAL FIX: React Query executes refetchInterval even when enabled=false
+      // We must check WebSocket connection state here to prevent misleading polling
+      if (isWebSocketConnected) {
+        console.log('✅ WebSocket connected - HTTP polling DISABLED');
+        return false; // Disable polling when WebSocket is active
+      }
+
+      const data = query.state.data;
+
+      // Stop polling if we get auth errors or training is completed
+      if (query.state.error && (query.state.error as any)?.status === 401) {
+        console.log('🚫 Stopping status polling due to auth error');
+        return false;
+      }
+      if (data?.status === 'completed' || data?.status === 'failed') {
+        console.log('🏁 Training completed - stopping HTTP polling');
+        return false; // Stop polling when training is done
+      }
+
+      console.log('📊 HTTP fallback polling active (WebSocket actually disconnected) - 5s interval');
+      return 5000; // Poll every 5 seconds while training (fallback when WebSocket unavailable)
+    },
     staleTime: 1000, // Consider data stale after 1 second
-    ...options,
+    retry: (failureCount, error) => {
+      // Don't retry on auth errors
+      if ((error as any)?.status === 401) {
+        console.log('🚫 Not retrying due to auth error');
+        return false;
+      }
+      return failureCount < 3;
+    },
+    ...queryOptions,
   });
 };
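Consumers are expected to feed the live WebSocket state back into this hook so HTTP polling only runs as a fallback. A minimal usage sketch, assuming it sits in the same module as the hooks above; the component name and the exact argument/option shapes of `useTrainingWebSocket` are assumptions for illustration.

```tsx
import * as React from 'react';

// Illustrative only: wire the WebSocket state into the status hook.
const TrainingStatusPanel: React.FC<{ tenantId: string; jobId: string }> = ({ tenantId, jobId }) => {
  const { isConnected } = useTrainingWebSocket(tenantId, jobId);
  const { data: status } = useTrainingJobStatus(tenantId, jobId, {
    isWebSocketConnected: isConnected, // HTTP polling stays off while the socket is up
  });
  return <span>{status?.status ?? 'unknown'}</span>;
};
```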
@@ -242,9 +289,9 @@ export const useTrainingWebSocket = (
   }
 ) => {
   const queryClient = useQueryClient();
-  const authToken = useAuthStore((state) => state.token);
   const [isConnected, setIsConnected] = React.useState(false);
   const [connectionError, setConnectionError] = React.useState<string | null>(null);
+  const [connectionAttempts, setConnectionAttempts] = React.useState(0);
 
   // Memoize options to prevent unnecessary effect re-runs
   const memoizedOptions = React.useMemo(() => options, [

@@ -266,20 +313,44 @@ export const useTrainingWebSocket = (
     let reconnectAttempts = 0;
     const maxReconnectAttempts = 3;
 
-    const connect = () => {
+    const connect = async () => {
       try {
         setConnectionError(null);
-        const effectiveToken = token || authToken;
+        setConnectionAttempts(prev => prev + 1);
+
+        // Use centralized token management from apiClient
+        let effectiveToken: string | null;
+
+        try {
+          // Always use the apiClient's token management
+          effectiveToken = await apiClient.ensureValidToken();
+
+          if (!effectiveToken) {
+            throw new Error('No valid token available');
+          }
+        } catch (error) {
+          console.error('❌ Failed to get valid token for WebSocket:', error);
+          setConnectionError('Authentication failed. Please log in again.');
+          return;
+        }
 
         console.log(`🔄 Attempting WebSocket connection (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts + 1}):`, {
           tenantId,
           jobId,
-          hasToken: !!effectiveToken
+          hasToken: !!effectiveToken,
+          tokenFromApiClient: true
         });
 
-        ws = trainingService.createWebSocketConnection(tenantId, jobId, token || authToken || undefined);
+        ws = trainingService.createWebSocketConnection(tenantId, jobId, effectiveToken);
 
         ws.onopen = () => {
-          console.log('✅ Training WebSocket connected successfully');
+          console.log('✅ Training WebSocket connected successfully', {
+            readyState: ws?.readyState,
+            url: ws?.url,
+            jobId
+          });
+          // Track connection time for debugging
+          (ws as any)._connectTime = Date.now();
          setIsConnected(true);
          reconnectAttempts = 0; // Reset on successful connection
 

@@ -291,23 +362,81 @@ export const useTrainingWebSocket = (
            console.warn('Failed to request status on connection:', e);
          }
 
-          // Set up periodic ping to keep connection alive
-          const pingInterval = setInterval(() => {
+          // Helper function to check if tokens represent different auth sessions
+          const isTokenSessionDifferent = (oldToken: string, newToken: string): boolean => {
+            if (!oldToken || !newToken) return !!oldToken !== !!newToken;
+
+            try {
+              const oldPayload = JSON.parse(atob(oldToken.split('.')[1]));
+              const newPayload = JSON.parse(atob(newToken.split('.')[1]));
+
+              // Compare by issued timestamp (iat) - different iat means new auth session
+              return oldPayload.iat !== newPayload.iat;
+            } catch (e) {
+              console.warn('Failed to parse token for session comparison, falling back to string comparison:', e);
+              return oldToken !== newToken;
+            }
+          };
+
+          // Set up periodic ping and intelligent token refresh detection
+          const heartbeatInterval = setInterval(async () => {
            if (ws?.readyState === WebSocket.OPEN && !isManuallyDisconnected) {
              try {
+                // Check token validity (this may refresh if needed)
+                const currentToken = await apiClient.ensureValidToken();
+
+                // Enhanced token change detection with detailed logging
+                const tokenStringChanged = currentToken !== effectiveToken;
+                const tokenSessionChanged = currentToken && effectiveToken ?
+                  isTokenSessionDifferent(effectiveToken, currentToken) : tokenStringChanged;
+
+                console.log('🔍 WebSocket token validation check:', {
+                  hasCurrentToken: !!currentToken,
+                  hasEffectiveToken: !!effectiveToken,
+                  tokenStringChanged,
+                  tokenSessionChanged,
+                  currentTokenPreview: currentToken ? `${currentToken.slice(0, 20)}...${currentToken.slice(-10)}` : 'null',
+                  effectiveTokenPreview: effectiveToken ? `${effectiveToken.slice(0, 20)}...${effectiveToken.slice(-10)}` : 'null'
+                });
+
+                // Only reconnect if we have a genuine session change (different iat)
+                if (tokenSessionChanged) {
+                  console.log('🔄 Token session changed - reconnecting WebSocket with new session token');
+                  console.log('📊 Session change details:', {
+                    reason: !currentToken ? 'token removed' :
+                      !effectiveToken ? 'token added' : 'new auth session',
+                    oldTokenIat: effectiveToken ? (() => {
+                      try { return JSON.parse(atob(effectiveToken.split('.')[1])).iat; } catch { return 'parse-error'; }
+                    })() : 'N/A',
+                    newTokenIat: currentToken ? (() => {
+                      try { return JSON.parse(atob(currentToken.split('.')[1])).iat; } catch { return 'parse-error'; }
+                    })() : 'N/A'
+                  });
+
+                  // Close current connection and trigger reconnection with new token
+                  ws?.close(1000, 'Token session changed - reconnecting');
+                  clearInterval(heartbeatInterval);
+                  return;
+                } else if (tokenStringChanged) {
+                  console.log('ℹ️ Token string changed but same session - continuing with current connection');
+                  // Update effective token reference for future comparisons
+                  effectiveToken = currentToken;
+                }
+
+                console.log('✅ Token validated during heartbeat - same session');
                ws?.send('ping');
-                console.log('💓 Sent ping to server');
+                console.log('💓 Sent ping to server (token session validated)');
              } catch (e) {
-                console.warn('Failed to send ping:', e);
-                clearInterval(pingInterval);
+                console.warn('Failed to send ping or validate token:', e);
+                clearInterval(heartbeatInterval);
              }
            } else {
-              clearInterval(pingInterval);
+              clearInterval(heartbeatInterval);
            }
-          }, 30000); // Ping every 30 seconds
+          }, 30000); // Check every 30 seconds for token refresh and send ping
 
          // Store interval for cleanup
-          (ws as any).pingInterval = pingInterval;
+          (ws as any).heartbeatInterval = heartbeatInterval;
        };
 
        ws.onmessage = (event) => {
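Stashing the interval handle on the socket via `(ws as any)` works, but it bypasses type checking. A hypothetical alternative, purely a sketch and not code from this commit, keeps the same lifetime coupling in a typed side table:

```ts
// Sketch: associate cleanup handles with a socket without mutating it.
const socketIntervals = new WeakMap<WebSocket, ReturnType<typeof setInterval>>();

function attachHeartbeat(ws: WebSocket, intervalMs: number, onBeat: () => void): void {
  socketIntervals.set(ws, setInterval(onBeat, intervalMs));
}

function detachHeartbeat(ws: WebSocket): void {
  const handle = socketIntervals.get(ws);
  if (handle !== undefined) {
    clearInterval(handle);
    socketIntervals.delete(ws); // the entry also vanishes once ws is garbage collected
  }
}
```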
@@ -404,13 +533,22 @@ export const useTrainingWebSocket = (
        };
 
        ws.onclose = (event) => {
-          console.log(`❌ Training WebSocket disconnected. Code: ${event.code}, Reason: "${event.reason}"`);
+          console.log(`❌ Training WebSocket disconnected. Code: ${event.code}, Reason: "${event.reason}"`, {
+            wasClean: event.wasClean,
+            jobId,
+            timeConnected: ws ? `${Date.now() - (ws as any)._connectTime || 0}ms` : 'unknown',
+            reconnectAttempts
+          });
          setIsConnected(false);
 
          // Detailed logging for different close codes
          switch (event.code) {
            case 1000:
+              if (event.reason === 'Token refreshed - reconnecting') {
+                console.log('🔄 WebSocket closed for token refresh - will reconnect immediately');
+              } else {
                console.log('🔒 WebSocket closed normally');
+              }
              break;
            case 1006:
              console.log('⚠️ WebSocket closed abnormally (1006) - likely server-side issue or network problem');

@@ -425,6 +563,15 @@ export const useTrainingWebSocket = (
              console.log(`❓ WebSocket closed with code ${event.code}`);
          }
 
+          // Handle token refresh reconnection (immediate reconnect)
+          if (event.code === 1000 && event.reason === 'Token refreshed - reconnecting') {
+            console.log('🔄 Reconnecting immediately due to token refresh...');
+            reconnectTimer = setTimeout(() => {
+              connect(); // Reconnect immediately with fresh token
+            }, 1000); // Short delay to allow cleanup
+            return;
+          }
+
          // Try to reconnect if not manually disconnected and haven't exceeded max attempts
          if (!isManuallyDisconnected && event.code !== 1000 && reconnectAttempts < maxReconnectAttempts) {
            const delay = Math.min(1000 * Math.pow(2, reconnectAttempts), 10000); // Exponential backoff, max 10s
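For reference, the backoff expression `Math.min(1000 * 2 ** attempt, 10000)` yields the schedule below; with `maxReconnectAttempts = 3`, only the 1 s, 2 s, and 4 s delays are actually used. This is just a worked check of the formula, not new behavior.

```ts
// attempt 0 → 1000 ms, 1 → 2000 ms, 2 → 4000 ms, 3 → 8000 ms, 4+ → capped at 10000 ms
const delays = [0, 1, 2, 3, 4, 5].map(a => Math.min(1000 * Math.pow(2, a), 10000));
console.log(delays); // [1000, 2000, 4000, 8000, 10000, 10000]
```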
@@ -470,7 +617,7 @@ export const useTrainingWebSocket = (
 
      setIsConnected(false);
    };
-  }, [tenantId, jobId, token, authToken, queryClient, memoizedOptions]);
+  }, [tenantId, jobId, queryClient, memoizedOptions]);
 
  return {
    isConnected,

@@ -479,17 +626,27 @@ export const useTrainingWebSocket = (
 };
 
 // Utility Hooks
-export const useIsTrainingInProgress = (tenantId: string, jobId?: string) => {
+export const useIsTrainingInProgress = (
+  tenantId: string,
+  jobId?: string,
+  isWebSocketConnected?: boolean
+) => {
   const { data: jobStatus } = useTrainingJobStatus(tenantId, jobId || '', {
     enabled: !!jobId,
+    isWebSocketConnected,
   });
 
   return jobStatus?.status === 'running' || jobStatus?.status === 'pending';
 };
 
-export const useTrainingProgress = (tenantId: string, jobId?: string) => {
+export const useTrainingProgress = (
+  tenantId: string,
+  jobId?: string,
+  isWebSocketConnected?: boolean
+) => {
   const { data: jobStatus } = useTrainingJobStatus(tenantId, jobId || '', {
     enabled: !!jobId,
+    isWebSocketConnected,
   });
 
   return {

@@ -94,12 +94,15 @@ export const OnboardingWizard: React.FC = () => {
   const { setCurrentTenant } = useTenantActions();
 
   // Auto-complete user_registered step if needed (runs first)
+  const [autoCompletionAttempted, setAutoCompletionAttempted] = React.useState(false);
+
   useEffect(() => {
-    if (userProgress && user?.id) {
+    if (userProgress && user?.id && !autoCompletionAttempted && !markStepCompleted.isPending) {
       const userRegisteredStep = userProgress.steps.find(s => s.step_name === 'user_registered');
 
       if (!userRegisteredStep?.completed) {
         console.log('🔄 Auto-completing user_registered step for new user...');
+        setAutoCompletionAttempted(true);
 
         markStepCompleted.mutate({
           userId: user.id,

@@ -116,11 +119,13 @@ export const OnboardingWizard: React.FC = () => {
           },
           onError: (error) => {
             console.error('❌ Failed to auto-complete user_registered step:', error);
+            // Reset flag on error to allow retry
+            setAutoCompletionAttempted(false);
           }
         });
       }
     }
-  }, [userProgress, user?.id, markStepCompleted]);
+  }, [userProgress, user?.id, autoCompletionAttempted, markStepCompleted.isPending]); // Removed markStepCompleted from deps
 
   // Initialize step index based on backend progress with validation
   useEffect(() => {
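Guarding a one-shot effect with state works, but each flag flip schedules another render. A commonly used alternative, sketched here under the assumption that retry-after-failure is still wanted, tracks the attempt in a ref, which persists across renders without triggering them; `useRunOnce` is an illustrative name.

```ts
import { useEffect, useRef } from 'react';

// Sketch: fire an async side effect at most once, allowing a retry after failure.
function useRunOnce(ready: boolean, action: () => Promise<void>) {
  const attempted = useRef(false);
  useEffect(() => {
    if (!ready || attempted.current) return;
    attempted.current = true;
    action().catch(() => {
      attempted.current = false; // allow a retry on the next render
    });
  }, [ready, action]);
}
```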
@@ -205,6 +210,12 @@ export const OnboardingWizard: React.FC = () => {
       return;
     }
 
+    // Prevent concurrent mutations
+    if (markStepCompleted.isPending) {
+      console.warn(`⚠️ Step completion already in progress for "${currentStep.id}", skipping duplicate call`);
+      return;
+    }
+
     console.log(`🎯 Completing step: "${currentStep.id}" with data:`, data);
 
     try {

@@ -267,6 +278,31 @@ export const OnboardingWizard: React.FC = () => {
 
       console.error(`📊 Error details: Status ${statusCode}, Message: ${errorMessage}`);
 
+      // Handle different types of errors
+      if (statusCode === 207) {
+        // Multi-Status: Step updated but summary failed
+        console.warn(`⚠️ Partial success for step "${currentStep.id}": ${errorMessage}`);
+
+        // Continue with step advancement since the actual step was completed
+        if (currentStep.id === 'completion') {
+          // Navigate to dashboard after completion
+          if (isNewTenant) {
+            navigate('/app/dashboard');
+          } else {
+            navigate('/app');
+          }
+        } else {
+          // Auto-advance to next step after successful completion
+          if (currentStepIndex < STEPS.length - 1) {
+            setCurrentStepIndex(currentStepIndex + 1);
+          }
+        }
+
+        // Show a warning but don't block progress
+        console.warn(`Step "${currentStep.title}" completed with warnings: ${errorMessage}`);
+        return; // Don't show error alert
+      }
+
       // Check if it's a dependency error
       if (errorMessage.includes('dependencies not met')) {
         console.error('🚫 Dependencies not met for step:', currentStep.id);

@@ -278,7 +314,7 @@ export const OnboardingWizard: React.FC = () => {
         }
       }
 
-      // Don't advance automatically on error - user should see the issue
+      // Don't advance automatically on real errors - user should see the issue
       alert(`${t('onboarding:errors.step_failed', 'Error al completar paso')} "${currentStep.title}": ${errorMessage}`);
     }
   };

@@ -1,7 +1,7 @@
 import React, { useState, useCallback, useEffect } from 'react';
 import { Button } from '../../../ui/Button';
 import { useCurrentTenant } from '../../../../stores/tenant.store';
-import { useCreateTrainingJob, useTrainingWebSocket } from '../../../../api/hooks/training';
+import { useCreateTrainingJob, useTrainingWebSocket, useTrainingJobStatus } from '../../../../api/hooks/training';
 
 interface MLTrainingStepProps {
   onNext: () => void;

@@ -85,6 +85,63 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
     } : undefined
   );
 
+  // Smart fallback polling - automatically disabled when WebSocket is connected
+  const { data: jobStatus } = useTrainingJobStatus(
+    currentTenant?.id || '',
+    jobId || '',
+    {
+      enabled: !!jobId && !!currentTenant?.id,
+      isWebSocketConnected: isConnected, // This will disable HTTP polling when WebSocket is connected
+    }
+  );
+
+  // Handle training status updates from HTTP polling (fallback only)
+  useEffect(() => {
+    if (!jobStatus || !jobId || trainingProgress?.stage === 'completed') {
+      return;
+    }
+
+    console.log('📊 HTTP fallback status update:', jobStatus);
+
+    // Check if training completed via HTTP polling fallback
+    if (jobStatus.status === 'completed' && trainingProgress?.stage !== 'completed') {
+      console.log('✅ Training completion detected via HTTP fallback');
+      setTrainingProgress({
+        stage: 'completed',
+        progress: 100,
+        message: 'Entrenamiento completado exitosamente (detectado por verificación HTTP)'
+      });
+      setIsTraining(false);
+
+      setTimeout(() => {
+        onComplete({
+          jobId: jobId,
+          success: true,
+          message: 'Modelo entrenado correctamente',
+          detectedViaPolling: true
+        });
+      }, 2000);
+    } else if (jobStatus.status === 'failed') {
+      console.log('❌ Training failure detected via HTTP fallback');
+      setError('Error detectado durante el entrenamiento (verificación de estado)');
+      setIsTraining(false);
+      setTrainingProgress(null);
+    } else if (jobStatus.status === 'running' && jobStatus.progress !== undefined) {
+      // Update progress if we have newer information from HTTP polling fallback
+      const currentProgress = trainingProgress?.progress || 0;
+      if (jobStatus.progress > currentProgress) {
+        console.log(`📈 Progress update via HTTP fallback: ${jobStatus.progress}%`);
+        setTrainingProgress(prev => ({
+          ...prev,
+          stage: 'training',
+          progress: jobStatus.progress,
+          message: jobStatus.message || 'Entrenando modelo...',
+          currentStep: jobStatus.current_step
+        }) as TrainingProgress);
+      }
+    }
+  }, [jobStatus, jobId, trainingProgress?.stage, onComplete]);
+
   // Auto-trigger training when component mounts
   useEffect(() => {
     if (currentTenant?.id && !isTraining && !trainingProgress && !error) {
@@ -249,7 +249,7 @@ async def events_stream(request: Request, token: str):
 
 @app.websocket("/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live")
 async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
-    """WebSocket proxy that forwards connections directly to training service"""
+    """WebSocket proxy that forwards connections directly to training service with enhanced token validation"""
     await websocket.accept()
 
     # Get token from query params

@@ -259,6 +259,29 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
         await websocket.close(code=1008, reason="Authentication token required")
         return
 
+    # Validate token using auth middleware
+    from app.middleware.auth import jwt_handler
+    try:
+        payload = jwt_handler.verify_token(token)
+        if not payload:
+            logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
+            await websocket.close(code=1008, reason="Invalid authentication token")
+            return
+
+        # Check token expiration
+        import time
+        if payload.get('exp', 0) < time.time():
+            logger.warning(f"WebSocket connection rejected - expired token for job {job_id}")
+            await websocket.close(code=1008, reason="Token expired")
+            return
+
+        logger.info(f"WebSocket token validated for user {payload.get('email', 'unknown')}")
+
+    except Exception as e:
+        logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
+        await websocket.close(code=1008, reason="Token validation failed")
+        return
+
     logger.info(f"Proxying WebSocket connection to training service for job {job_id}, tenant {tenant_id}")
 
     # Build WebSocket URL to training service

@@ -266,43 +289,153 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
     training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
     training_ws_url = f"{training_ws_url}/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
 
+    training_ws = None
+    heartbeat_task = None
+
     try:
-        # Connect to training service WebSocket
+        # Connect to training service WebSocket with proper timeout configuration
         import websockets
-        async with websockets.connect(training_ws_url) as training_ws:
-            logger.info(f"Connected to training service WebSocket for job {job_id}")
+
+        # Configure timeouts to coordinate with frontend (30s heartbeat) and training service
+        # DISABLE gateway-level ping to avoid dual-ping conflicts - let frontend handle ping/pong
+        training_ws = await websockets.connect(
+            training_ws_url,
+            ping_interval=None,  # DISABLED: Let frontend handle ping/pong via message forwarding
+            ping_timeout=None,   # DISABLED: No independent ping mechanism
+            close_timeout=15,    # Reasonable close timeout
+            max_size=2**20,      # 1MB max message size
+            max_queue=32         # Max queued messages
+        )
+
+        logger.info(f"Connected to training service WebSocket for job {job_id} with gateway ping DISABLED (frontend handles ping/pong)")
+
+        # Track connection state properly due to FastAPI WebSocket state propagation bug
+        connection_alive = True
+        last_activity = asyncio.get_event_loop().time()
+
+        async def check_connection_health():
+            """Monitor connection health based on activity timestamps only - no WebSocket interference"""
+            nonlocal connection_alive, last_activity
+
+            while connection_alive:
+                try:
+                    await asyncio.sleep(30)  # Check every 30 seconds (aligned with frontend heartbeat)
+                    current_time = asyncio.get_event_loop().time()
+
+                    # Check if we haven't received any activity for too long
+                    # Frontend sends ping every 30s, so 90s = 3 missed pings before considering dead
+                    if current_time - last_activity > 90:
+                        logger.warning(f"No frontend activity for 90s on job {job_id} - connection may be dead")
+                        # Don't forcibly close - let the forwarding loops handle actual connection issues
+                        # This is just monitoring/logging now
+                    else:
+                        logger.debug(f"Connection health OK for job {job_id} - last activity {int(current_time - last_activity)}s ago")
+
+                except Exception as e:
+                    logger.error(f"Connection health monitoring error for job {job_id}: {e}")
+                    break
 
         async def forward_to_training():
-            """Forward messages from frontend to training service"""
+            """Forward messages from frontend to training service with proper error handling"""
+            nonlocal connection_alive, last_activity
+
             try:
-                async for message in websocket.iter_text():
+                while connection_alive and training_ws and training_ws.open:
+                    try:
+                        # Use longer timeout to avoid conflicts with frontend 30s heartbeat
+                        # Frontend sends ping every 30s, so we need to allow for some latency
+                        message = await asyncio.wait_for(websocket.receive_text(), timeout=45.0)
+                        last_activity = asyncio.get_event_loop().time()
+
                         # Forward the message to training service
                         await training_ws.send(message)
+                        logger.debug(f"Forwarded message to training service for job {job_id}: {message[:100]}...")
+
+                    except asyncio.TimeoutError:
+                        # No message received in 45 seconds, continue loop
+                        # This allows for frontend 30s heartbeat + network latency + processing time
+                        continue
                     except Exception as e:
-                logger.error(f"Error forwarding to training service: {e}")
+                        logger.error(f"Error receiving from frontend for job {job_id}: {e}")
+                        connection_alive = False
+                        break
+
+            except Exception as e:
+                logger.error(f"Error in forward_to_training for job {job_id}: {e}")
+                connection_alive = False
 
         async def forward_to_frontend():
-            """Forward messages from training service to frontend"""
-            try:
-                async for message in training_ws:
-                    await websocket.send_text(message)
-            except Exception as e:
-                logger.error(f"Error forwarding to frontend: {e}")
-
-        # Run both forwarding tasks concurrently
+            """Forward messages from training service to frontend with proper error handling"""
+            nonlocal connection_alive, last_activity
+
+            try:
+                while connection_alive and training_ws and training_ws.open:
+                    try:
+                        # Use coordinated timeout - training service expects messages every 60s
+                        # This should be longer than training service timeout to avoid premature closure
+                        message = await asyncio.wait_for(training_ws.recv(), timeout=75.0)
+                        last_activity = asyncio.get_event_loop().time()
+
+                        # Forward the message to frontend
+                        await websocket.send_text(message)
+                        logger.debug(f"Forwarded message to frontend for job {job_id}: {message[:100]}...")
+
+                    except asyncio.TimeoutError:
+                        # No message received in 75 seconds, continue loop
+                        # Training service sends heartbeats, so this indicates potential issues
+                        continue
+                    except Exception as e:
+                        logger.error(f"Error receiving from training service for job {job_id}: {e}")
+                        connection_alive = False
+                        break
+
+            except Exception as e:
+                logger.error(f"Error in forward_to_frontend for job {job_id}: {e}")
+                connection_alive = False
+
+        # Start connection health monitoring
+        heartbeat_task = asyncio.create_task(check_connection_health())
+
+        # Run both forwarding tasks concurrently with proper error handling
+        try:
             await asyncio.gather(
                 forward_to_training(),
                 forward_to_frontend(),
                 return_exceptions=True
             )
+        except Exception as e:
+            logger.error(f"Error in WebSocket forwarding tasks for job {job_id}: {e}")
+        finally:
+            connection_alive = False
 
+    except websockets.exceptions.ConnectionClosedError as e:
+        logger.warning(f"Training service WebSocket connection closed for job {job_id}: {e}")
+    except websockets.exceptions.WebSocketException as e:
+        logger.error(f"WebSocket exception for job {job_id}: {e}")
     except Exception as e:
         logger.error(f"WebSocket proxy error for job {job_id}: {e}")
         try:
             await websocket.close(code=1011, reason="Training service connection failed")
         except:
             pass
     finally:
-        logger.info(f"WebSocket proxy closed for job {job_id}")
+        # Cleanup
+        if heartbeat_task and not heartbeat_task.done():
+            heartbeat_task.cancel()
+            try:
+                await heartbeat_task
+            except asyncio.CancelledError:
+                pass
+
+        if training_ws and not training_ws.closed:
+            try:
+                await training_ws.close()
+            except Exception as e:
+                logger.warning(f"Error closing training service WebSocket for job {job_id}: {e}")
+
+        try:
+            if not websocket.client_state.name == 'DISCONNECTED':
+                await websocket.close(code=1000, reason="Proxy connection closed")
+        except Exception as e:
+            logger.warning(f"Error closing frontend WebSocket for job {job_id}: {e}")
+
+        logger.info(f"WebSocket proxy cleanup completed for job {job_id}")
 
 if __name__ == "__main__":
     import uvicorn
@@ -67,7 +67,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
             )
 
         # ✅ STEP 2: Verify token and get user context
-        user_context = await self._verify_token(token)
+        user_context = await self._verify_token(token, request)
         if not user_context:
             logger.warning(f"Invalid token for route: {request.url.path}")
             return JSONResponse(

@@ -117,7 +117,14 @@ class AuthMiddleware(BaseHTTPMiddleware):
                     tenant_id=tenant_id,
                     path=request.url.path)
 
-        return await call_next(request)
+        # Process the request
+        response = await call_next(request)
+
+        # Add token expiry warning header if token is near expiry
+        if hasattr(request.state, 'token_near_expiry') and request.state.token_near_expiry:
+            response.headers["X-Token-Refresh-Suggested"] = "true"
+
+        return response
 
     def _is_public_route(self, path: str) -> bool:
         """Check if route requires authentication"""

@@ -130,7 +137,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
             return auth_header.split(" ")[1]
         return None
 
-    async def _verify_token(self, token: str) -> Optional[Dict[str, Any]]:
+    async def _verify_token(self, token: str, request: Request = None) -> Optional[Dict[str, Any]]:
         """
         Verify JWT token with improved fallback strategy
         FIXED: Better error handling and token structure validation

@@ -141,6 +148,17 @@ class AuthMiddleware(BaseHTTPMiddleware):
             payload = jwt_handler.verify_token(token)
             if payload and self._validate_token_payload(payload):
                 logger.debug("Token validated locally")
+
+                # Check if token is near expiry and set flag for response header
+                if request:
+                    import time
+                    exp_time = payload.get("exp", 0)
+                    current_time = time.time()
+                    time_until_expiry = exp_time - current_time
+
+                    if time_until_expiry < 300:  # 5 minutes
+                        request.state.token_near_expiry = True
+
                 # Convert JWT payload to user context format
                 return self._jwt_payload_to_user_context(payload)
         except Exception as e:

@@ -188,6 +206,14 @@ class AuthMiddleware(BaseHTTPMiddleware):
             logger.warning(f"Invalid token type: {payload.get('type')}")
             return False
 
+        # Check if token is near expiry (within 5 minutes) and log warning
+        import time
+        exp_time = payload.get("exp", 0)
+        current_time = time.time()
+        time_until_expiry = exp_time - current_time
+
+        if time_until_expiry < 300:  # 5 minutes
+            logger.warning(f"Token expires in {int(time_until_expiry)} seconds for user {payload.get('email')}")
+
         return True
 

@@ -88,7 +88,7 @@ data:
   # AUTHENTICATION & SECURITY SETTINGS
   # ================================================================
   JWT_ALGORITHM: "HS256"
-  JWT_ACCESS_TOKEN_EXPIRE_MINUTES: "30"
+  JWT_ACCESS_TOKEN_EXPIRE_MINUTES: "240"
   JWT_REFRESH_TOKEN_EXPIRE_DAYS: "7"
   ENABLE_SERVICE_AUTH: "false"
   PASSWORD_MIN_LENGTH: "8"
@@ -140,6 +140,19 @@ class OnboardingService:
                 }
             )
 
+        # Try to update summary and handle partial failures gracefully
+        try:
+            # Update the user's onboarding summary
+            await self._update_user_summary(user_id)
+        except HTTPException as he:
+            # If it's a 207 Multi-Status (partial success), log warning but continue
+            if he.status_code == status.HTTP_207_MULTI_STATUS:
+                logger.warning(f"Summary update failed for user {user_id}, step {step_name}: {he.detail}")
+                # Continue execution - the step update was successful
+            else:
+                # Re-raise other HTTP exceptions
+                raise
+
         # Return updated progress
         return await self.get_user_progress(user_id)
 

@@ -285,9 +298,6 @@ class OnboardingService:
             step_data=data_payload
         )
 
-        # Update the user's onboarding summary
-        await self._update_user_summary(user_id)
-
         logger.info(f"Successfully updated onboarding step for user {user_id}: {step_name} = {step_data}")
         return updated_step
 

@@ -330,9 +340,17 @@ class OnboardingService:
                 steps_completed_count=steps_completed_count
             )
 
             logger.debug(f"Successfully updated onboarding summary for user {user_id}")
 
         except Exception as e:
-            logger.error(f"Error updating onboarding summary for user {user_id}: {e}")
-            # Don't raise here - summary update failure shouldn't break step updates
+            logger.error(f"Error updating onboarding summary for user {user_id}: {e}",
+                         extra={"user_id": user_id, "error_type": type(e).__name__})
+            # Raise a warning-level HTTPException to inform frontend without breaking the flow
+            # This allows the step update to succeed while alerting about summary issues
+            raise HTTPException(
+                status_code=status.HTTP_207_MULTI_STATUS,
+                detail=f"Step updated successfully, but summary update failed: {str(e)}"
+            )
 
 # API Routes
 

@@ -61,11 +61,11 @@ class UserOnboardingSummary(Base):
     # Summary fields
     current_step = Column(String(50), nullable=False, default="user_registered")
     next_step = Column(String(50))
-    completion_percentage = Column(String(10), default="0.0")  # Store as string for precision
+    completion_percentage = Column(String(50), default="0.0")  # Store as string for precision
     fully_completed = Column(Boolean, default=False)
 
     # Progress tracking
-    steps_completed_count = Column(String(10), default="0")  # Store as string: "3/5"
+    steps_completed_count = Column(String(50), default="0")  # Store as string: "3/5"
 
     # Timestamps
     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
@@ -82,10 +82,44 @@ async def training_progress_websocket(
     tenant_id: str,
     job_id: str
 ):
-    connection_id = f"{tenant_id}_{id(websocket)}"
+    # Validate token from query parameters
+    token = websocket.query_params.get("token")
+    if not token:
+        logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
+        await websocket.close(code=1008, reason="Authentication token required")
+        return
+
+    # Validate the token (use the same JWT handler as gateway)
+    from shared.auth.jwt_handler import JWTHandler
+    from app.core.config import settings
+
+    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
+
+    try:
+        payload = jwt_handler.verify_token(token)
+        if not payload:
+            logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
+            await websocket.close(code=1008, reason="Invalid authentication token")
+            return
+
+        # Verify user has access to this tenant
+        user_id = payload.get('user_id')
+        if not user_id:
+            logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
+            await websocket.close(code=1008, reason="Invalid token payload")
+            return
+
+        logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")
+
+    except Exception as e:
+        logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
+        await websocket.close(code=1008, reason="Token validation failed")
+        return
+
+    connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"
 
     await connection_manager.connect(websocket, job_id, connection_id)
-    logger.info(f"WebSocket connection established for job {job_id}")
+    logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")
 
     consumer_task = None
     training_completed = False

@@ -100,9 +134,10 @@ async def training_progress_websocket(
 
         while not training_completed:
             try:
-                # FIXED: Use receive() instead of receive_text()
+                # Coordinate with frontend 30s heartbeat + gateway 45s timeout
+                # This should be longer than gateway timeout to avoid premature closure
                 try:
-                    data = await asyncio.wait_for(websocket.receive(), timeout=30.0)
+                    data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
                     last_activity = asyncio.get_event_loop().time()
 
                     # Handle different message types

@@ -125,29 +160,39 @@ async def training_progress_websocket(
                         break
 
                     elif "bytes" in data:
-                        # Handle binary messages (WebSocket ping frames)
+                        # Handle binary messages (WebSocket ping frames) - respond with text pong for compatibility
                         await websocket.send_text("pong")
-                        logger.debug(f"Binary ping received for job {job_id}")
+                        logger.debug(f"Binary ping received for job {job_id}, responding with text pong")
 
                     elif data["type"] == "websocket.disconnect":
                         logger.info(f"WebSocket disconnect message received for job {job_id}")
                         break
 
                 except asyncio.TimeoutError:
-                    # No message received in 30 seconds - send heartbeat
+                    # No message received in 60 seconds - this is now coordinated with gateway timeouts
                     current_time = asyncio.get_event_loop().time()
-                    if current_time - last_activity > 60:
-                        logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat")
+
+                    # Send heartbeat only if we haven't received frontend ping for too long
+                    # Frontend sends ping every 30s, so 60s timeout + 30s grace = 90s before heartbeat
+                    if current_time - last_activity > 90:  # 90 seconds of total inactivity
+                        logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")
+
                         try:
                             await websocket.send_json({
                                 "type": "heartbeat",
                                 "job_id": job_id,
-                                "timestamp": str(datetime.datetime.now())
+                                "timestamp": str(datetime.datetime.now()),
+                                "message": "Training service heartbeat - frontend inactive",
+                                "inactivity_seconds": int(current_time - last_activity)
                             })
                             last_activity = current_time
                         except Exception as e:
                             logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
                             break
+                    else:
+                        # Normal timeout, frontend should be sending ping every 30s
+                        logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
                     continue
 
     except WebSocketDisconnect:
         logger.info(f"WebSocket client disconnected for job {job_id}")
@@ -48,6 +48,66 @@ async def get_db_health() -> bool:
         logger.error("Database health check failed", error=str(e))
         return False
 
+async def get_comprehensive_db_health() -> dict:
+    """
+    Comprehensive health check that verifies both connectivity and table existence
+    """
+    health_status = {
+        "status": "healthy",
+        "connectivity": False,
+        "tables_exist": False,
+        "tables_verified": [],
+        "missing_tables": [],
+        "errors": []
+    }
+
+    try:
+        # Test basic connectivity
+        health_status["connectivity"] = await get_db_health()
+
+        if not health_status["connectivity"]:
+            health_status["status"] = "unhealthy"
+            health_status["errors"].append("Database connectivity failed")
+            return health_status
+
+        # Test table existence
+        tables_verified = await _verify_tables_exist()
+        health_status["tables_exist"] = tables_verified
+
+        if tables_verified:
+            health_status["tables_verified"] = [
+                'model_training_logs', 'trained_models', 'model_performance_metrics',
+                'training_job_queue', 'model_artifacts'
+            ]
+        else:
+            health_status["status"] = "unhealthy"
+            health_status["errors"].append("Required tables missing or inaccessible")
+
+            # Try to identify which specific tables are missing
+            try:
+                async with database_manager.get_session() as session:
+                    for table_name in ['model_training_logs', 'trained_models', 'model_performance_metrics',
+                                       'training_job_queue', 'model_artifacts']:
+                        try:
+                            await session.execute(text(f"SELECT 1 FROM {table_name} LIMIT 1"))
+                            health_status["tables_verified"].append(table_name)
+                        except Exception:
+                            health_status["missing_tables"].append(table_name)
+            except Exception as e:
+                health_status["errors"].append(f"Error checking individual tables: {str(e)}")
+
+        logger.debug("Comprehensive database health check completed",
+                     status=health_status["status"],
+                     connectivity=health_status["connectivity"],
+                     tables_exist=health_status["tables_exist"])
+
+    except Exception as e:
+        health_status["status"] = "unhealthy"
+        health_status["errors"].append(f"Health check failed: {str(e)}")
+        logger.error("Comprehensive database health check failed", error=str(e))
+
+    return health_status
+
 # Training service specific database utilities
 class TrainingDatabaseUtils:
     """Training service specific database utilities"""

@@ -223,11 +283,27 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
 
 # Database initialization for training service
 async def initialize_training_database():
-    """Initialize database tables for training service"""
-    try:
-        logger.info("Initializing training service database")
+    """Initialize database tables for training service with retry logic and verification"""
+    import asyncio
+    from sqlalchemy import text
 
-        # Import models to ensure they're registered
+    max_retries = 5
+    retry_delay = 2.0
+
+    for attempt in range(1, max_retries + 1):
+        try:
+            logger.info("Initializing training service database",
+                        attempt=attempt, max_retries=max_retries)
+
+            # Step 1: Test database connectivity first
+            logger.info("Testing database connectivity...")
+            connection_ok = await database_manager.test_connection()
+            if not connection_ok:
+                raise Exception("Database connection test failed")
+            logger.info("Database connectivity verified")
+
+            # Step 2: Import models to ensure they're registered
+            logger.info("Importing and registering database models...")
             from app.models.training import (
                 ModelTrainingLog,
                 TrainedModel,

@@ -236,14 +312,89 @@ async def initialize_training_database():
                 ModelArtifact
             )
 
-        # Create tables using shared infrastructure
+            # Verify models are registered in metadata
+            expected_tables = {
+                'model_training_logs', 'trained_models', 'model_performance_metrics',
+                'training_job_queue', 'model_artifacts'
+            }
+            registered_tables = set(Base.metadata.tables.keys())
+            missing_tables = expected_tables - registered_tables
+            if missing_tables:
+                raise Exception(f"Models not properly registered: {missing_tables}")
+
+            logger.info("Models registered successfully",
+                        tables=list(registered_tables))
+
+            # Step 3: Create tables using shared infrastructure with verification
+            logger.info("Creating database tables...")
             await database_manager.create_tables()
 
-        logger.info("Training service database initialized successfully")
+            # Step 4: Verify tables were actually created
+            logger.info("Verifying table creation...")
+            verification_successful = await _verify_tables_exist()
+
+            if not verification_successful:
+                raise Exception("Table verification failed - tables were not created properly")
+
+            logger.info("Training service database initialized and verified successfully",
+                        attempt=attempt)
+            return
 
         except Exception as e:
-        logger.error("Failed to initialize training service database", error=str(e))
-        raise
+            logger.error("Database initialization failed",
+                         attempt=attempt,
+                         max_retries=max_retries,
+                         error=str(e))
+
+            if attempt == max_retries:
+                logger.error("All database initialization attempts failed - giving up")
+                raise Exception(f"Failed to initialize training database after {max_retries} attempts: {str(e)}")
+
+            # Wait before retry with exponential backoff
+            wait_time = retry_delay * (2 ** (attempt - 1))
+            logger.info("Retrying database initialization",
+                        retry_in_seconds=wait_time,
+                        next_attempt=attempt + 1)
+            await asyncio.sleep(wait_time)
+
+async def _verify_tables_exist() -> bool:
+    """Verify that all required tables exist in the database"""
+    try:
+        async with database_manager.get_session() as session:
+            # Check each required table exists and is accessible
+            required_tables = [
+                'model_training_logs',
+                'trained_models',
+                'model_performance_metrics',
+                'training_job_queue',
+                'model_artifacts'
+            ]
+
+            for table_name in required_tables:
+                try:
+                    # Try to query the table structure
+                    result = await session.execute(
+                        text(f"SELECT 1 FROM {table_name} LIMIT 1")
+                    )
+                    logger.debug(f"Table {table_name} exists and is accessible")
+                except Exception as table_error:
+                    # If it's a "relation does not exist" error, table creation failed
+                    if "does not exist" in str(table_error).lower():
+                        logger.error(f"Table {table_name} does not exist", error=str(table_error))
+                        return False
+                    # If it's an empty table, that's fine - table exists
+                    elif "no data" in str(table_error).lower():
+                        logger.debug(f"Table {table_name} exists but is empty (normal)")
+                    else:
+                        logger.warning(f"Unexpected error querying {table_name}", error=str(table_error))
+
+            logger.info("All required tables verified successfully",
+                        tables=required_tables)
+            return True
+
+    except Exception as e:
+        logger.error("Table verification failed", error=str(e))
+        return False
+
 # Database cleanup for training service
 async def cleanup_training_database():
@@ -16,7 +16,7 @@ from fastapi.responses import JSONResponse
 import uvicorn
 
 from app.core.config import settings
-from app.core.database import initialize_training_database, cleanup_training_database, get_db_health
+from app.core.database import initialize_training_database, cleanup_training_database, get_db_health, get_comprehensive_db_health
 from app.api import training, models
 
 from app.api.websocket import websocket_router

@@ -195,18 +195,69 @@ async def health_check():
 
 @app.get("/health/ready")
 async def readiness_check():
-    """Kubernetes readiness probe endpoint"""
+    """Kubernetes readiness probe endpoint with comprehensive database checks"""
     try:
+        # Get comprehensive database health including table verification
+        db_health = await get_comprehensive_db_health()
+
         checks = {
-            "database": await get_db_health(),
+            "database_connectivity": db_health["connectivity"],
+            "database_tables": db_health["tables_exist"],
             "application": getattr(app.state, 'ready', False)
         }
 
-        if all(checks.values()):
-            return {"status": "ready", "checks": checks}
+        # Include detailed database info for debugging
+        database_details = {
+            "status": db_health["status"],
+            "tables_verified": db_health["tables_verified"],
+            "missing_tables": db_health["missing_tables"],
+            "errors": db_health["errors"]
+        }
+
+        # Service is ready only if all checks pass
+        all_ready = all(checks.values()) and db_health["status"] == "healthy"
+
+        if all_ready:
+            return {
+                "status": "ready",
+                "checks": checks,
+                "database": database_details
+            }
         else:
             return JSONResponse(
                 status_code=503,
-                content={"status": "not ready", "checks": checks}
+                content={
+                    "status": "not ready",
+                    "checks": checks,
+                    "database": database_details
+                }
             )
 
     except Exception as e:
         logger.error("Readiness check failed", error=str(e))
         return JSONResponse(
             status_code=503,
             content={
                 "status": "not ready",
                 "error": f"Health check failed: {str(e)}"
             }
         )
 
+@app.get("/health/database")
+async def database_health_check():
+    """Detailed database health endpoint for debugging"""
+    try:
+        db_health = await get_comprehensive_db_health()
+        status_code = 200 if db_health["status"] == "healthy" else 503
+        return JSONResponse(status_code=status_code, content=db_health)
+    except Exception as e:
+        logger.error("Database health check failed", error=str(e))
+        return JSONResponse(
+            status_code=503,
+            content={
+                "status": "unhealthy",
+                "error": f"Health check failed: {str(e)}"
+            }
+        )
+
 @app.get("/metrics")

@@ -220,11 +271,6 @@ async def get_metrics():
 async def liveness_check():
     return {"status": "alive"}
 
-@app.get("/health/ready")
-async def readiness_check():
-    ready = getattr(app.state, 'ready', True)
-    return {"status": "ready" if ready else "not ready"}
-
 @app.get("/")
 async def root():
     return {"service": "training-service", "version": "1.0.0"}
@@ -150,12 +150,33 @@ class DatabaseManager:
     # ===== TABLE MANAGEMENT =====
 
     async def create_tables(self, metadata=None):
-        """Create database tables"""
+        """Create database tables with enhanced error handling and transaction verification"""
         try:
             target_metadata = metadata or Base.metadata
+            table_names = list(target_metadata.tables.keys())
+            logger.info(f"Creating tables: {table_names}", service=self.service_name)
 
+            # Use explicit transaction with proper error handling
             async with self.async_engine.begin() as conn:
+                try:
+                    # Create tables within the transaction
                     await conn.run_sync(target_metadata.create_all, checkfirst=True)
+
+                    # Verify transaction is not in error state
+                    # Try a simple query to ensure connection is still valid
+                    await conn.execute(text("SELECT 1"))
+
+                    logger.info("Database tables creation transaction completed successfully",
+                                service=self.service_name, tables=table_names)
+
+                except Exception as create_error:
+                    logger.error(f"Error during table creation within transaction: {create_error}",
+                                 service=self.service_name)
+                    # Re-raise to trigger transaction rollback
+                    raise
 
             logger.info("Database tables created successfully", service=self.service_name)
 
         except Exception as e:
             # Check if it's a "relation already exists" error which can be safely ignored
             error_str = str(e).lower()

@@ -164,6 +185,14 @@ class DatabaseManager:
                 logger.info("Database tables creation completed (some already existed)", service=self.service_name)
             else:
                 logger.error(f"Failed to create tables: {e}", service=self.service_name)
+
+                # Check for specific transaction error indicators
+                if any(indicator in error_str for indicator in [
+                    "transaction", "rollback", "aborted", "failed sql transaction"
+                ]):
+                    logger.error("Transaction-related error detected during table creation",
+                                 service=self.service_name)
+
                 raise DatabaseError(f"Table creation failed: {str(e)}")
 
     async def drop_tables(self, metadata=None):