Add base kubernetes support final fix 4

Urtzi Alfaro
2025-09-29 07:54:25 +02:00
parent 57f77638cc
commit 4777e59e7a
14 changed files with 1041 additions and 167 deletions

View File

@@ -91,7 +91,27 @@ class ApiClient {
// Response interceptor for error handling and automatic token refresh
this.client.interceptors.response.use(
(response) => {
// Enhanced logging for token refresh header detection
const refreshSuggested = response.headers['x-token-refresh-suggested'];
if (refreshSuggested) {
console.log('🔍 TOKEN REFRESH HEADER DETECTED:', {
url: response.config?.url,
method: response.config?.method,
status: response.status,
refreshSuggested,
hasRefreshToken: !!this.refreshToken,
currentTokenLength: this.authToken?.length || 0
});
}
// Check if server suggests token refresh
if (refreshSuggested === 'true' && this.refreshToken) {
console.log('🔄 Server suggests token refresh - refreshing proactively');
this.proactiveTokenRefresh();
}
return response;
},
async (error) => {
const originalRequest = error.config;
@@ -228,6 +248,40 @@ class ApiClient {
}
}
private async proactiveTokenRefresh() {
// Avoid multiple simultaneous proactive refreshes
if (this.isRefreshing) {
return;
}
try {
this.isRefreshing = true;
console.log('🔄 Proactively refreshing token...');
const response = await this.client.post('/auth/refresh', {
refresh_token: this.refreshToken
});
const { access_token, refresh_token } = response.data;
// Update tokens
this.setAuthToken(access_token);
if (refresh_token) {
this.setRefreshToken(refresh_token);
}
// Update auth store
await this.updateAuthStore(access_token, refresh_token);
console.log('✅ Proactive token refresh successful');
} catch (error) {
console.warn('⚠️ Proactive token refresh failed:', error);
// Don't handle as auth failure here - let the next 401 handle it
} finally {
this.isRefreshing = false;
}
}
private async handleAuthFailure() {
try {
// Clear tokens
@@ -275,6 +329,117 @@ class ApiClient {
return this.tenantId;
}
// Token synchronization methods for WebSocket connections
getCurrentValidToken(): string | null {
return this.authToken;
}
async ensureValidToken(): Promise<string | null> {
const originalToken = this.authToken;
const originalTokenShort = originalToken ? `${originalToken.slice(0, 20)}...${originalToken.slice(-10)}` : 'null';
console.log('🔍 ensureValidToken() called:', {
hasToken: !!this.authToken,
tokenPreview: originalTokenShort,
isRefreshing: this.isRefreshing,
hasRefreshToken: !!this.refreshToken
});
// If we have a valid token, return it
if (this.authToken && !this.isTokenNearExpiry(this.authToken)) {
const expiryInfo = this.getTokenExpiryInfo(this.authToken);
console.log('✅ Token is valid, returning current token:', {
tokenPreview: originalTokenShort,
expiryInfo
});
return this.authToken;
}
// If token is near expiry or expired, try to refresh
if (this.refreshToken && !this.isRefreshing) {
console.log('🔄 Token needs refresh, attempting proactive refresh:', {
reason: this.authToken ? 'near expiry' : 'no token',
expiryInfo: this.authToken ? this.getTokenExpiryInfo(this.authToken) : 'N/A'
});
try {
await this.proactiveTokenRefresh();
const newTokenShort = this.authToken ? `${this.authToken.slice(0, 20)}...${this.authToken.slice(-10)}` : 'null';
const tokenChanged = originalToken !== this.authToken;
console.log('✅ Token refresh completed:', {
tokenChanged,
oldTokenPreview: originalTokenShort,
newTokenPreview: newTokenShort,
newExpiryInfo: this.authToken ? this.getTokenExpiryInfo(this.authToken) : 'N/A'
});
return this.authToken;
} catch (error) {
console.warn('❌ Failed to refresh token in ensureValidToken:', error);
return null;
}
}
console.log('⚠️ Returning current token without refresh:', {
reason: this.isRefreshing ? 'already refreshing' : 'no refresh token',
tokenPreview: originalTokenShort
});
return this.authToken;
}
private getTokenExpiryInfo(token: string): any {
try {
const payload = JSON.parse(atob(token.split('.')[1]));
const exp = payload.exp;
const iat = payload.iat;
if (!exp) return { error: 'No expiry in token' };
const now = Math.floor(Date.now() / 1000);
const timeUntilExpiry = exp - now;
const tokenLifetime = exp - iat;
return {
issuedAt: new Date(iat * 1000).toISOString(),
expiresAt: new Date(exp * 1000).toISOString(),
lifetimeMinutes: Math.floor(tokenLifetime / 60),
secondsUntilExpiry: timeUntilExpiry,
minutesUntilExpiry: Math.floor(timeUntilExpiry / 60),
isNearExpiry: timeUntilExpiry < 300,
isExpired: timeUntilExpiry <= 0
};
} catch (error) {
return { error: 'Failed to parse token', details: error };
}
}
private isTokenNearExpiry(token: string): boolean {
try {
const payload = JSON.parse(atob(token.split('.')[1]));
const exp = payload.exp;
if (!exp) return false;
const now = Math.floor(Date.now() / 1000);
const timeUntilExpiry = exp - now;
// Consider token near expiry if less than 5 minutes remaining
const isNear = timeUntilExpiry < 300;
if (isNear) {
console.log('⏰ Token is near expiry:', {
secondsUntilExpiry: timeUntilExpiry,
minutesUntilExpiry: Math.floor(timeUntilExpiry / 60),
expiresAt: new Date(exp * 1000).toISOString()
});
}
return isNear;
} catch (error) {
console.warn('Failed to parse token for expiry check:', error);
return true; // Assume expired if we can't parse
}
}
// HTTP Methods - Return direct data for React Query
async get<T = any>(url: string, config?: AxiosRequestConfig): Promise<T> {
const response: AxiosResponse<T> = await this.client.get(url, config);
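For reference, a minimal sketch (illustrative only, not part of this commit) of how the 5-minute near-expiry threshold used above behaves; the hand-built token and its payload fields are assumptions for the example (real JWTs are base64url-encoded, this sketch uses plain base64 for simplicity):

// Build an unsigned JWT-shaped token whose payload expires in `secondsFromNow` seconds.
const makeToken = (secondsFromNow: number): string => {
  const now = Math.floor(Date.now() / 1000);
  const payload = { exp: now + secondsFromNow, iat: now };
  return `header.${btoa(JSON.stringify(payload))}.signature`;
};

// Mirrors the check in isTokenNearExpiry(): anything under 300 seconds counts as near expiry.
const isNearExpiry = (token: string): boolean => {
  const { exp } = JSON.parse(atob(token.split('.')[1]));
  return exp - Math.floor(Date.now() / 1000) < 300;
};

console.log(isNearExpiry(makeToken(60)));   // true  - 1 minute left, refresh proactively
console.log(isNearExpiry(makeToken(3600))); // false - 1 hour left, keep using the token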

View File

@@ -90,24 +90,35 @@ export const useUpdateStep = (
export const useMarkStepCompleted = (
options?: UseMutationOptions<
UserProgress,
ApiError,
{ userId: string; stepName: string; data?: Record<string, any> }
>
) => {
const queryClient = useQueryClient();
return useMutation<
UserProgress,
ApiError,
{ userId: string; stepName: string; data?: Record<string, any> }
>({
mutationFn: ({ userId, stepName, data }) =>
onboardingService.markStepCompleted(userId, stepName, data),
onSuccess: (data, { userId }) => {
// Update progress cache with new data
queryClient.setQueryData(onboardingKeys.progress(userId), data);
// Invalidate the query to ensure fresh data on next access
queryClient.invalidateQueries({ queryKey: onboardingKeys.progress(userId) });
},
onError: (error, { userId, stepName }) => {
console.error(`Failed to complete step ${stepName} for user ${userId}:`, error);
// Invalidate queries on error to ensure we get fresh data
queryClient.invalidateQueries({ queryKey: onboardingKeys.progress(userId) });
},
// Use a stable mutation key so React Query can deduplicate in-flight step completions
mutationKey: ['markStepCompleted'],
...options,
});
};
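A minimal usage sketch of the hook above (illustrative only; the component, import path, and step payload are assumptions, not part of this commit):

import React from 'react';
import { useMarkStepCompleted } from './onboarding'; // hypothetical path to the hook above

export const CompleteRegistrationButton: React.FC<{ userId: string }> = ({ userId }) => {
  const markStepCompleted = useMarkStepCompleted();

  const handleClick = () => {
    // Skip if a completion is already in flight to avoid duplicate requests.
    if (markStepCompleted.isPending) return;

    // onSuccess in the hook writes the returned progress into the cache for
    // onboardingKeys.progress(userId) and then invalidates it for a fresh refetch.
    markStepCompleted.mutate({
      userId,
      stepName: 'user_registered',
      data: { auto_completed: true },
    });
  };

  return <button onClick={handleClick}>Complete registration</button>;
};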

View File

@@ -3,10 +3,10 @@
* Provides data fetching, caching, and state management for training operations
*/
import * as React from 'react';
import { useMutation, useQuery, useQueryClient, UseQueryOptions, UseMutationOptions } from '@tanstack/react-query';
import { trainingService } from '../services/training';
import { ApiError, apiClient } from '../client/apiClient';
import { useAuthStore } from '../../stores/auth.store';
import type {
TrainingJobRequest,
@@ -53,15 +53,62 @@ export const trainingKeys = {
export const useTrainingJobStatus = (
tenantId: string,
jobId: string,
options?: Omit<UseQueryOptions<TrainingJobStatus, ApiError>, 'queryKey' | 'queryFn'> & {
isWebSocketConnected?: boolean;
}
) => {
const { isWebSocketConnected, ...queryOptions } = options || {};
// Completely disable the query when WebSocket is connected
const isEnabled = !!tenantId && !!jobId && !isWebSocketConnected;
console.log('🔄 Training status query:', {
tenantId: !!tenantId,
jobId: !!jobId,
isWebSocketConnected,
queryEnabled: isEnabled
});
return useQuery<TrainingJobStatus, ApiError>({
queryKey: trainingKeys.jobs.status(tenantId, jobId),
queryFn: () => {
console.log('📡 Executing HTTP training status query (WebSocket disconnected)');
return trainingService.getTrainingJobStatus(tenantId, jobId);
},
enabled: isEnabled, // Completely disable when WebSocket connected
refetchInterval: (query) => {
// CRITICAL FIX: React Query executes refetchInterval even when enabled=false
// We must check WebSocket connection state here to prevent misleading polling
if (isWebSocketConnected) {
console.log('✅ WebSocket connected - HTTP polling DISABLED');
return false; // Disable polling when WebSocket is active
}
const data = query.state.data;
// Stop polling if we get auth errors or training is completed
if (query.state.error && (query.state.error as any)?.status === 401) {
console.log('🚫 Stopping status polling due to auth error');
return false;
}
if (data?.status === 'completed' || data?.status === 'failed') {
console.log('🏁 Training completed - stopping HTTP polling');
return false; // Stop polling when training is done
}
console.log('📊 HTTP fallback polling active (WebSocket actually disconnected) - 5s interval');
return 5000; // Poll every 5 seconds while training (fallback when WebSocket unavailable)
},
staleTime: 1000, // Consider data stale after 1 second
retry: (failureCount, error) => {
// Don't retry on auth errors
if ((error as any)?.status === 401) {
console.log('🚫 Not retrying due to auth error');
return false;
}
return failureCount < 3;
},
...queryOptions,
});
};
@@ -242,9 +289,9 @@ export const useTrainingWebSocket = (
}
) => {
const queryClient = useQueryClient();
const [isConnected, setIsConnected] = React.useState(false);
const [connectionError, setConnectionError] = React.useState<string | null>(null);
const [connectionAttempts, setConnectionAttempts] = React.useState(0);
// Memoize options to prevent unnecessary effect re-runs
const memoizedOptions = React.useMemo(() => options, [
@@ -266,20 +313,44 @@ export const useTrainingWebSocket = (
let reconnectAttempts = 0;
const maxReconnectAttempts = 3;
const connect = async () => {
try {
setConnectionError(null);
setConnectionAttempts(prev => prev + 1);
// Use centralized token management from apiClient
let effectiveToken: string | null;
try {
// Always use the apiClient's token management
effectiveToken = await apiClient.ensureValidToken();
if (!effectiveToken) {
throw new Error('No valid token available');
}
} catch (error) {
console.error('❌ Failed to get valid token for WebSocket:', error);
setConnectionError('Authentication failed. Please log in again.');
return;
}
console.log(`🔄 Attempting WebSocket connection (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts + 1}):`, {
tenantId,
jobId,
hasToken: !!effectiveToken,
tokenFromApiClient: true
});
ws = trainingService.createWebSocketConnection(tenantId, jobId, effectiveToken);
ws.onopen = () => {
console.log('✅ Training WebSocket connected successfully', {
readyState: ws?.readyState,
url: ws?.url,
jobId
});
// Track connection time for debugging
(ws as any)._connectTime = Date.now();
setIsConnected(true);
reconnectAttempts = 0; // Reset on successful connection
@@ -291,23 +362,81 @@ export const useTrainingWebSocket = (
console.warn('Failed to request status on connection:', e);
}
// Helper function to check if tokens represent different auth sessions
const isTokenSessionDifferent = (oldToken: string, newToken: string): boolean => {
if (!oldToken || !newToken) return !!oldToken !== !!newToken;
try {
const oldPayload = JSON.parse(atob(oldToken.split('.')[1]));
const newPayload = JSON.parse(atob(newToken.split('.')[1]));
// Compare by issued timestamp (iat) - different iat means new auth session
return oldPayload.iat !== newPayload.iat;
} catch (e) {
console.warn('Failed to parse token for session comparison, falling back to string comparison:', e);
return oldToken !== newToken;
}
};
// Set up periodic ping and intelligent token refresh detection
const heartbeatInterval = setInterval(async () => {
if (ws?.readyState === WebSocket.OPEN && !isManuallyDisconnected) {
try {
// Check token validity (this may refresh if needed)
const currentToken = await apiClient.ensureValidToken();
// Enhanced token change detection with detailed logging
const tokenStringChanged = currentToken !== effectiveToken;
const tokenSessionChanged = currentToken && effectiveToken ?
isTokenSessionDifferent(effectiveToken, currentToken) : tokenStringChanged;
console.log('🔍 WebSocket token validation check:', {
hasCurrentToken: !!currentToken,
hasEffectiveToken: !!effectiveToken,
tokenStringChanged,
tokenSessionChanged,
currentTokenPreview: currentToken ? `${currentToken.slice(0, 20)}...${currentToken.slice(-10)}` : 'null',
effectiveTokenPreview: effectiveToken ? `${effectiveToken.slice(0, 20)}...${effectiveToken.slice(-10)}` : 'null'
});
// Only reconnect if we have a genuine session change (different iat)
if (tokenSessionChanged) {
console.log('🔄 Token session changed - reconnecting WebSocket with new session token');
console.log('📊 Session change details:', {
reason: !currentToken ? 'token removed' :
!effectiveToken ? 'token added' : 'new auth session',
oldTokenIat: effectiveToken ? (() => {
try { return JSON.parse(atob(effectiveToken.split('.')[1])).iat; } catch { return 'parse-error'; }
})() : 'N/A',
newTokenIat: currentToken ? (() => {
try { return JSON.parse(atob(currentToken.split('.')[1])).iat; } catch { return 'parse-error'; }
})() : 'N/A'
});
// Close current connection and trigger reconnection with new token
ws?.close(1000, 'Token session changed - reconnecting');
clearInterval(heartbeatInterval);
return;
} else if (tokenStringChanged) {
console.log(' Token string changed but same session - continuing with current connection');
// Update effective token reference for future comparisons
effectiveToken = currentToken;
}
console.log('✅ Token validated during heartbeat - same session');
ws?.send('ping');
console.log('💓 Sent ping to server (token session validated)');
} catch (e) {
console.warn('Failed to send ping or validate token:', e);
clearInterval(heartbeatInterval);
}
} else {
clearInterval(heartbeatInterval);
}
}, 30000); // Check every 30 seconds for token refresh and send ping
// Store interval for cleanup
(ws as any).heartbeatInterval = heartbeatInterval;
};
ws.onmessage = (event) => {
@@ -404,13 +533,22 @@ export const useTrainingWebSocket = (
};
ws.onclose = (event) => {
console.log(`❌ Training WebSocket disconnected. Code: ${event.code}, Reason: "${event.reason}"`, {
wasClean: event.wasClean,
jobId,
timeConnected: ws ? `${Date.now() - (ws as any)._connectTime || 0}ms` : 'unknown',
reconnectAttempts
});
setIsConnected(false);
// Detailed logging for different close codes
switch (event.code) {
case 1000:
if (event.reason === 'Token refreshed - reconnecting') {
console.log('🔄 WebSocket closed for token refresh - will reconnect immediately');
} else {
console.log('🔒 WebSocket closed normally');
}
break;
case 1006:
console.log('⚠️ WebSocket closed abnormally (1006) - likely server-side issue or network problem');
@@ -425,11 +563,20 @@ export const useTrainingWebSocket = (
console.log(`❓ WebSocket closed with code ${event.code}`);
}
// Handle token refresh reconnection (immediate reconnect)
if (event.code === 1000 && event.reason === 'Token refreshed - reconnecting') {
console.log('🔄 Reconnecting immediately due to token refresh...');
reconnectTimer = setTimeout(() => {
connect(); // Reconnect immediately with fresh token
}, 1000); // Short delay to allow cleanup
return;
}
// Try to reconnect if not manually disconnected and haven't exceeded max attempts
if (!isManuallyDisconnected && event.code !== 1000 && reconnectAttempts < maxReconnectAttempts) {
const delay = Math.min(1000 * Math.pow(2, reconnectAttempts), 10000); // Exponential backoff, max 10s
console.log(`🔄 Attempting to reconnect WebSocket in ${delay/1000}s... (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})`);
reconnectTimer = setTimeout(() => {
reconnectAttempts++;
connect();
@@ -470,7 +617,7 @@ export const useTrainingWebSocket = (
setIsConnected(false);
};
}, [tenantId, jobId, queryClient, memoizedOptions]);
return {
isConnected,
@@ -479,17 +626,27 @@ export const useTrainingWebSocket = (
};
// Utility Hooks
export const useIsTrainingInProgress = (
tenantId: string,
jobId?: string,
isWebSocketConnected?: boolean
) => {
const { data: jobStatus } = useTrainingJobStatus(tenantId, jobId || '', {
enabled: !!jobId,
isWebSocketConnected,
});
return jobStatus?.status === 'running' || jobStatus?.status === 'pending';
};
export const useTrainingProgress = (
tenantId: string,
jobId?: string,
isWebSocketConnected?: boolean
) => {
const { data: jobStatus } = useTrainingJobStatus(tenantId, jobId || '', {
enabled: !!jobId,
isWebSocketConnected,
});
return {
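As a rough illustration of the heartbeat's session check above (not part of this commit): two tokens with the same `iat` are treated as the same auth session even if the token string changed, while a different `iat` forces a reconnect. The hand-built tokens below are assumptions for the example:

const fakeToken = (payload: object): string =>
  `h.${btoa(JSON.stringify(payload))}.s`;

// Same comparison as isTokenSessionDifferent(): look at the issued-at timestamps.
const sessionChanged = (oldToken: string, newToken: string): boolean => {
  const oldIat = JSON.parse(atob(oldToken.split('.')[1])).iat;
  const newIat = JSON.parse(atob(newToken.split('.')[1])).iat;
  return oldIat !== newIat;
};

const tokenA  = fakeToken({ iat: 1700000000, exp: 1700003600 });
const tokenA2 = fakeToken({ iat: 1700000000, exp: 1700007200 }); // same iat, different string
const tokenB  = fakeToken({ iat: 1700009000, exp: 1700012600 }); // different iat

console.log(sessionChanged(tokenA, tokenA2)); // false - same session, WebSocket stays open
console.log(sessionChanged(tokenA, tokenB));  // true  - new session, close and reconnect with the new token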

View File

@@ -94,18 +94,21 @@ export const OnboardingWizard: React.FC = () => {
const { setCurrentTenant } = useTenantActions();
// Auto-complete user_registered step if needed (runs first)
const [autoCompletionAttempted, setAutoCompletionAttempted] = React.useState(false);
useEffect(() => {
if (userProgress && user?.id && !autoCompletionAttempted && !markStepCompleted.isPending) {
const userRegisteredStep = userProgress.steps.find(s => s.step_name === 'user_registered');
if (!userRegisteredStep?.completed) {
console.log('🔄 Auto-completing user_registered step for new user...');
setAutoCompletionAttempted(true);
markStepCompleted.mutate({
userId: user.id,
stepName: 'user_registered',
data: {
auto_completed: true,
completed_at: new Date().toISOString(),
source: 'onboarding_wizard_auto_completion'
}
@@ -116,11 +119,13 @@ export const OnboardingWizard: React.FC = () => {
},
onError: (error) => {
console.error('❌ Failed to auto-complete user_registered step:', error);
// Reset flag on error to allow retry
setAutoCompletionAttempted(false);
}
});
}
}
}, [userProgress, user?.id, autoCompletionAttempted, markStepCompleted.isPending]); // Removed markStepCompleted from deps
// Initialize step index based on backend progress with validation
useEffect(() => {
@@ -205,6 +210,12 @@ export const OnboardingWizard: React.FC = () => {
return;
}
// Prevent concurrent mutations
if (markStepCompleted.isPending) {
console.warn(`⚠️ Step completion already in progress for "${currentStep.id}", skipping duplicate call`);
return;
}
console.log(`🎯 Completing step: "${currentStep.id}" with data:`, data);
try {
@@ -260,25 +271,50 @@ export const OnboardingWizard: React.FC = () => {
}
} catch (error: any) {
console.error(`❌ Error completing step "${currentStep.id}":`, error);
// Extract detailed error information
const errorMessage = error?.response?.data?.detail || error?.message || 'Unknown error';
const statusCode = error?.response?.status;
console.error(`📊 Error details: Status ${statusCode}, Message: ${errorMessage}`);
// Handle different types of errors
if (statusCode === 207) {
// Multi-Status: Step updated but summary failed
console.warn(`⚠️ Partial success for step "${currentStep.id}": ${errorMessage}`);
// Continue with step advancement since the actual step was completed
if (currentStep.id === 'completion') {
// Navigate to dashboard after completion
if (isNewTenant) {
navigate('/app/dashboard');
} else {
navigate('/app');
}
} else {
// Auto-advance to next step after successful completion
if (currentStepIndex < STEPS.length - 1) {
setCurrentStepIndex(currentStepIndex + 1);
}
}
// Show a warning but don't block progress
console.warn(`Step "${currentStep.title}" completed with warnings: ${errorMessage}`);
return; // Don't show error alert
}
// Check if it's a dependency error
if (errorMessage.includes('dependencies not met')) {
console.error('🚫 Dependencies not met for step:', currentStep.id);
// Check what dependencies are missing
if (userProgress) {
console.log('📋 Current progress:', userProgress);
console.log('📋 Completed steps:', userProgress.steps.filter(s => s.completed).map(s => s.step_name));
}
}
// Don't advance automatically on real errors - user should see the issue
alert(`${t('onboarding:errors.step_failed', 'Error al completar paso')} "${currentStep.title}": ${errorMessage}`);
}
};

View File

@@ -1,7 +1,7 @@
import React, { useState, useCallback, useEffect } from 'react';
import { Button } from '../../../ui/Button';
import { useCurrentTenant } from '../../../../stores/tenant.store';
import { useCreateTrainingJob, useTrainingWebSocket, useTrainingJobStatus } from '../../../../api/hooks/training';
interface MLTrainingStepProps {
onNext: () => void;
@@ -85,6 +85,63 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
} : undefined
);
// Smart fallback polling - automatically disabled when WebSocket is connected
const { data: jobStatus } = useTrainingJobStatus(
currentTenant?.id || '',
jobId || '',
{
enabled: !!jobId && !!currentTenant?.id,
isWebSocketConnected: isConnected, // This will disable HTTP polling when WebSocket is connected
}
);
// Handle training status updates from HTTP polling (fallback only)
useEffect(() => {
if (!jobStatus || !jobId || trainingProgress?.stage === 'completed') {
return;
}
console.log('📊 HTTP fallback status update:', jobStatus);
// Check if training completed via HTTP polling fallback
if (jobStatus.status === 'completed' && trainingProgress?.stage !== 'completed') {
console.log('✅ Training completion detected via HTTP fallback');
setTrainingProgress({
stage: 'completed',
progress: 100,
message: 'Entrenamiento completado exitosamente (detectado por verificación HTTP)'
});
setIsTraining(false);
setTimeout(() => {
onComplete({
jobId: jobId,
success: true,
message: 'Modelo entrenado correctamente',
detectedViaPolling: true
});
}, 2000);
} else if (jobStatus.status === 'failed') {
console.log('❌ Training failure detected via HTTP fallback');
setError('Error detectado durante el entrenamiento (verificación de estado)');
setIsTraining(false);
setTrainingProgress(null);
} else if (jobStatus.status === 'running' && jobStatus.progress !== undefined) {
// Update progress if we have newer information from HTTP polling fallback
const currentProgress = trainingProgress?.progress || 0;
if (jobStatus.progress > currentProgress) {
console.log(`📈 Progress update via HTTP fallback: ${jobStatus.progress}%`);
setTrainingProgress(prev => ({
...prev,
stage: 'training',
progress: jobStatus.progress,
message: jobStatus.message || 'Entrenando modelo...',
currentStep: jobStatus.current_step
}) as TrainingProgress);
}
}
}, [jobStatus, jobId, trainingProgress?.stage, onComplete]);
// Auto-trigger training when component mounts
useEffect(() => {
if (currentTenant?.id && !isTraining && !trainingProgress && !error) {

View File

@@ -249,60 +249,193 @@ async def events_stream(request: Request, token: str):
@app.websocket("/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
"""WebSocket proxy that forwards connections directly to training service with enhanced token validation"""
await websocket.accept()
# Get token from query params
token = websocket.query_params.get("token")
if not token:
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
await websocket.close(code=1008, reason="Authentication token required")
return
# Validate token using auth middleware
from app.middleware.auth import jwt_handler
try:
payload = jwt_handler.verify_token(token)
if not payload:
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
await websocket.close(code=1008, reason="Invalid authentication token")
return
# Check token expiration
import time
if payload.get('exp', 0) < time.time():
logger.warning(f"WebSocket connection rejected - expired token for job {job_id}")
await websocket.close(code=1008, reason="Token expired")
return
logger.info(f"WebSocket token validated for user {payload.get('email', 'unknown')}")
except Exception as e:
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
await websocket.close(code=1008, reason="Token validation failed")
return
logger.info(f"Proxying WebSocket connection to training service for job {job_id}, tenant {tenant_id}")
# Build WebSocket URL to training service
training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
training_ws_url = f"{training_ws_url}/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
training_ws = None
heartbeat_task = None
try:
# Connect to training service WebSocket with proper timeout configuration
import websockets
# Configure timeouts to coordinate with frontend (30s heartbeat) and training service
# DISABLE gateway-level ping to avoid dual-ping conflicts - let frontend handle ping/pong
training_ws = await websockets.connect(
training_ws_url,
ping_interval=None, # DISABLED: Let frontend handle ping/pong via message forwarding
ping_timeout=None, # DISABLED: No independent ping mechanism
close_timeout=15, # Reasonable close timeout
max_size=2**20, # 1MB max message size
max_queue=32 # Max queued messages
)
logger.info(f"Connected to training service WebSocket for job {job_id} with gateway ping DISABLED (frontend handles ping/pong)")
# Track connection state properly due to FastAPI WebSocket state propagation bug
connection_alive = True
last_activity = asyncio.get_event_loop().time()
async def check_connection_health():
"""Monitor connection health based on activity timestamps only - no WebSocket interference"""
nonlocal connection_alive, last_activity
while connection_alive:
try:
await asyncio.sleep(30) # Check every 30 seconds (aligned with frontend heartbeat)
current_time = asyncio.get_event_loop().time()
# Check if we haven't received any activity for too long
# Frontend sends ping every 30s, so 90s = 3 missed pings before considering dead
if current_time - last_activity > 90:
logger.warning(f"No frontend activity for 90s on job {job_id} - connection may be dead")
# Don't forcibly close - let the forwarding loops handle actual connection issues
# This is just monitoring/logging now
else:
logger.debug(f"Connection health OK for job {job_id} - last activity {int(current_time - last_activity)}s ago")
except Exception as e:
logger.error(f"Connection health monitoring error for job {job_id}: {e}")
break
async def forward_to_training():
"""Forward messages from frontend to training service with proper error handling"""
nonlocal connection_alive, last_activity
try:
while connection_alive and training_ws and training_ws.open:
try:
# Use longer timeout to avoid conflicts with frontend 30s heartbeat
# Frontend sends ping every 30s, so we need to allow for some latency
message = await asyncio.wait_for(websocket.receive_text(), timeout=45.0)
last_activity = asyncio.get_event_loop().time()
# Forward the message to training service
await training_ws.send(message)
logger.debug(f"Forwarded message to training service for job {job_id}: {message[:100]}...")
except asyncio.TimeoutError:
# No message received in 45 seconds, continue loop
# This allows for frontend 30s heartbeat + network latency + processing time
continue
except Exception as e:
logger.error(f"Error receiving from frontend for job {job_id}: {e}")
connection_alive = False
break
except Exception as e:
logger.error(f"Error in forward_to_training for job {job_id}: {e}")
connection_alive = False
async def forward_to_frontend():
"""Forward messages from training service to frontend with proper error handling"""
nonlocal connection_alive, last_activity
try:
while connection_alive and training_ws and training_ws.open:
try:
# Use coordinated timeout - training service expects messages every 60s
# This should be longer than training service timeout to avoid premature closure
message = await asyncio.wait_for(training_ws.recv(), timeout=75.0)
last_activity = asyncio.get_event_loop().time()
# Forward the message to frontend
await websocket.send_text(message)
logger.debug(f"Forwarded message to frontend for job {job_id}: {message[:100]}...")
except asyncio.TimeoutError:
# No message received in 75 seconds, continue loop
# Training service sends heartbeats, so this indicates potential issues
continue
except Exception as e:
logger.error(f"Error receiving from training service for job {job_id}: {e}")
connection_alive = False
break
except Exception as e:
logger.error(f"Error in forward_to_frontend for job {job_id}: {e}")
connection_alive = False
# Start connection health monitoring
heartbeat_task = asyncio.create_task(check_connection_health())
# Run both forwarding tasks concurrently with proper error handling
try:
await asyncio.gather(
forward_to_training(),
forward_to_frontend(),
return_exceptions=True
)
except Exception as e:
logger.error(f"Error in WebSocket forwarding tasks for job {job_id}: {e}")
finally:
connection_alive = False
except websockets.exceptions.ConnectionClosedError as e:
logger.warning(f"Training service WebSocket connection closed for job {job_id}: {e}")
except websockets.exceptions.WebSocketException as e:
logger.error(f"WebSocket exception for job {job_id}: {e}")
except Exception as e:
logger.error(f"WebSocket proxy error for job {job_id}: {e}")
try:
await websocket.close(code=1011, reason="Training service connection failed")
except:
pass
finally:
# Cleanup
if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass
if training_ws and not training_ws.closed:
try:
await training_ws.close()
except Exception as e:
logger.warning(f"Error closing training service WebSocket for job {job_id}: {e}")
try:
if not websocket.client_state.name == 'DISCONNECTED':
await websocket.close(code=1000, reason="Proxy connection closed")
except Exception as e:
logger.warning(f"Error closing frontend WebSocket for job {job_id}: {e}")
logger.info(f"WebSocket proxy cleanup completed for job {job_id}")
if __name__ == "__main__":
import uvicorn

View File

@@ -67,7 +67,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
)
# ✅ STEP 2: Verify token and get user context
user_context = await self._verify_token(token, request)
if not user_context:
logger.warning(f"Invalid token for route: {request.url.path}")
return JSONResponse(
@@ -117,7 +117,14 @@ class AuthMiddleware(BaseHTTPMiddleware):
tenant_id=tenant_id,
path=request.url.path)
# Process the request
response = await call_next(request)
# Add token expiry warning header if token is near expiry
if hasattr(request.state, 'token_near_expiry') and request.state.token_near_expiry:
response.headers["X-Token-Refresh-Suggested"] = "true"
return response
def _is_public_route(self, path: str) -> bool:
"""Check if route requires authentication"""
@@ -130,7 +137,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
return auth_header.split(" ")[1]
return None
async def _verify_token(self, token: str, request: Request = None) -> Optional[Dict[str, Any]]:
"""
Verify JWT token with improved fallback strategy
FIXED: Better error handling and token structure validation
@@ -141,6 +148,17 @@ class AuthMiddleware(BaseHTTPMiddleware):
payload = jwt_handler.verify_token(token)
if payload and self._validate_token_payload(payload):
logger.debug("Token validated locally")
# Check if token is near expiry and set flag for response header
if request:
import time
exp_time = payload.get("exp", 0)
current_time = time.time()
time_until_expiry = exp_time - current_time
if time_until_expiry < 300: # 5 minutes
request.state.token_near_expiry = True
# Convert JWT payload to user context format
return self._jwt_payload_to_user_context(payload)
except Exception as e:
@@ -177,18 +195,26 @@
"""
required_fields = ["user_id", "email", "exp", "type"]
missing_fields = [field for field in required_fields if field not in payload]
if missing_fields:
logger.warning(f"Token payload missing fields: {missing_fields}")
return False
# Validate token type
token_type = payload.get("type")
if token_type not in ["access", "service"]:
logger.warning(f"Invalid token type: {payload.get('type')}")
return False
# Check if token is near expiry (within 5 minutes) and log warning
import time
exp_time = payload.get("exp", 0)
current_time = time.time()
time_until_expiry = exp_time - current_time
if time_until_expiry < 300: # 5 minutes
logger.warning(f"Token expires in {int(time_until_expiry)} seconds for user {payload.get('email')}")
return True
def _jwt_payload_to_user_context(self, payload: Dict[str, Any]) -> Dict[str, Any]:

View File

@@ -88,7 +88,7 @@ data:
# AUTHENTICATION & SECURITY SETTINGS
# ================================================================
JWT_ALGORITHM: "HS256"
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: "240"
JWT_REFRESH_TOKEN_EXPIRE_DAYS: "7"
ENABLE_SERVICE_AUTH: "false"
PASSWORD_MIN_LENGTH: "8"

View File

@@ -131,15 +131,28 @@ class OnboardingService:
# Update the step
await self._update_user_onboarding_data(
user_id,
step_name,
{
"completed": update_request.completed,
"completed_at": datetime.now(timezone.utc).isoformat() if update_request.completed else None,
"data": update_request.data or {}
}
)
# Try to update summary and handle partial failures gracefully
try:
# Update the user's onboarding summary
await self._update_user_summary(user_id)
except HTTPException as he:
# If it's a 207 Multi-Status (partial success), log warning but continue
if he.status_code == status.HTTP_207_MULTI_STATUS:
logger.warning(f"Summary update failed for user {user_id}, step {step_name}: {he.detail}")
# Continue execution - the step update was successful
else:
# Re-raise other HTTP exceptions
raise
# Return updated progress
return await self.get_user_progress(user_id)
@@ -284,10 +297,7 @@ class OnboardingService:
completed=completed,
step_data=data_payload
)
logger.info(f"Successfully updated onboarding step for user {user_id}: {step_name} = {step_data}")
return updated_step
@@ -300,26 +310,26 @@ class OnboardingService:
try:
# Get updated progress
user_progress_data = await self._get_user_onboarding_data(user_id)
# Calculate current status
completed_steps = []
for step_name in ONBOARDING_STEPS:
if user_progress_data.get(step_name, {}).get("completed", False):
completed_steps.append(step_name)
# Determine current and next step
current_step = self._get_current_step(completed_steps)
next_step = self._get_next_step(completed_steps)
# Calculate completion percentage
completion_percentage = (len(completed_steps) / len(ONBOARDING_STEPS)) * 100
# Check if fully completed
fully_completed = len(completed_steps) == len(ONBOARDING_STEPS)
# Format steps count
steps_completed_count = f"{len(completed_steps)}/{len(ONBOARDING_STEPS)}"
# Update summary in database
await self.onboarding_repo.upsert_user_summary(
user_id=user_id,
@@ -329,10 +339,18 @@ class OnboardingService:
fully_completed=fully_completed,
steps_completed_count=steps_completed_count
)
logger.debug(f"Successfully updated onboarding summary for user {user_id}")
except Exception as e:
logger.error(f"Error updating onboarding summary for user {user_id}: {e}",
extra={"user_id": user_id, "error_type": type(e).__name__})
# Raise a warning-level HTTPException to inform frontend without breaking the flow
# This allows the step update to succeed while alerting about summary issues
raise HTTPException(
status_code=status.HTTP_207_MULTI_STATUS,
detail=f"Step updated successfully, but summary update failed: {str(e)}"
)
# API Routes

View File

@@ -61,11 +61,11 @@ class UserOnboardingSummary(Base):
# Summary fields
current_step = Column(String(50), nullable=False, default="user_registered")
next_step = Column(String(50))
completion_percentage = Column(String(50), default="0.0") # Store as string for precision
fully_completed = Column(Boolean, default=False)
# Progress tracking
steps_completed_count = Column(String(50), default="0") # Store as string: "3/5"
# Timestamps
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))

View File

@@ -82,11 +82,45 @@ async def training_progress_websocket(
tenant_id: str,
job_id: str
):
# Validate token from query parameters
token = websocket.query_params.get("token")
if not token:
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
await websocket.close(code=1008, reason="Authentication token required")
return
# Validate the token (use the same JWT handler as gateway)
from shared.auth.jwt_handler import JWTHandler
from app.core.config import settings
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload:
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
await websocket.close(code=1008, reason="Invalid authentication token")
return
# Verify user has access to this tenant
user_id = payload.get('user_id')
if not user_id:
logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
await websocket.close(code=1008, reason="Invalid token payload")
return
logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")
except Exception as e:
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
await websocket.close(code=1008, reason="Token validation failed")
return
connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"
await connection_manager.connect(websocket, job_id, connection_id)
logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")
consumer_task = None
training_completed = False
@@ -100,11 +134,12 @@ async def training_progress_websocket(
while not training_completed:
try:
# Coordinate with frontend 30s heartbeat + gateway 45s timeout
# This should be longer than gateway timeout to avoid premature closure
try:
data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
last_activity = asyncio.get_event_loop().time()
# Handle different message types
if data["type"] == "websocket.receive":
if "text" in data:
@@ -123,31 +158,41 @@ async def training_progress_websocket(
elif message_text == "close":
logger.info(f"Client requested connection close for job {job_id}")
break
elif "bytes" in data:
# Handle binary messages (WebSocket ping frames) - respond with text pong for compatibility
await websocket.send_text("pong")
logger.debug(f"Binary ping received for job {job_id}, responding with text pong")
elif data["type"] == "websocket.disconnect":
logger.info(f"WebSocket disconnect message received for job {job_id}")
break
except asyncio.TimeoutError:
# No message received in 60 seconds - this is now coordinated with gateway timeouts
current_time = asyncio.get_event_loop().time()
# Send heartbeat only if we haven't received frontend ping for too long
# Frontend sends ping every 30s, so 60s timeout + 30s grace = 90s before heartbeat
if current_time - last_activity > 90: # 90 seconds of total inactivity
logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")
try:
await websocket.send_json({
"type": "heartbeat",
"job_id": job_id,
"timestamp": str(datetime.datetime.now()),
"message": "Training service heartbeat - frontend inactive",
"inactivity_seconds": int(current_time - last_activity)
})
last_activity = current_time
except Exception as e:
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
break
else:
# Normal timeout, frontend should be sending ping every 30s
logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
continue
except WebSocketDisconnect:
logger.info(f"WebSocket client disconnected for job {job_id}")

View File

@@ -43,11 +43,71 @@ async def get_db_health() -> bool:
await conn.execute(text("SELECT 1"))
logger.debug("Database health check passed")
return True
except Exception as e:
logger.error("Database health check failed", error=str(e))
return False
async def get_comprehensive_db_health() -> dict:
"""
Comprehensive health check that verifies both connectivity and table existence
"""
health_status = {
"status": "healthy",
"connectivity": False,
"tables_exist": False,
"tables_verified": [],
"missing_tables": [],
"errors": []
}
try:
# Test basic connectivity
health_status["connectivity"] = await get_db_health()
if not health_status["connectivity"]:
health_status["status"] = "unhealthy"
health_status["errors"].append("Database connectivity failed")
return health_status
# Test table existence
tables_verified = await _verify_tables_exist()
health_status["tables_exist"] = tables_verified
if tables_verified:
health_status["tables_verified"] = [
'model_training_logs', 'trained_models', 'model_performance_metrics',
'training_job_queue', 'model_artifacts'
]
else:
health_status["status"] = "unhealthy"
health_status["errors"].append("Required tables missing or inaccessible")
# Try to identify which specific tables are missing
try:
async with database_manager.get_session() as session:
for table_name in ['model_training_logs', 'trained_models', 'model_performance_metrics',
'training_job_queue', 'model_artifacts']:
try:
await session.execute(text(f"SELECT 1 FROM {table_name} LIMIT 1"))
health_status["tables_verified"].append(table_name)
except Exception:
health_status["missing_tables"].append(table_name)
except Exception as e:
health_status["errors"].append(f"Error checking individual tables: {str(e)}")
logger.debug("Comprehensive database health check completed",
status=health_status["status"],
connectivity=health_status["connectivity"],
tables_exist=health_status["tables_exist"])
except Exception as e:
health_status["status"] = "unhealthy"
health_status["errors"].append(f"Health check failed: {str(e)}")
logger.error("Comprehensive database health check failed", error=str(e))
return health_status
# Training service specific database utilities
class TrainingDatabaseUtils:
"""Training service specific database utilities"""
@@ -223,27 +283,118 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
# Database initialization for training service
async def initialize_training_database():
"""Initialize database tables for training service with retry logic and verification"""
import asyncio
from sqlalchemy import text
max_retries = 5
retry_delay = 2.0
for attempt in range(1, max_retries + 1):
try:
logger.info("Initializing training service database",
attempt=attempt, max_retries=max_retries)
# Step 1: Test database connectivity first
logger.info("Testing database connectivity...")
connection_ok = await database_manager.test_connection()
if not connection_ok:
raise Exception("Database connection test failed")
logger.info("Database connectivity verified")
# Step 2: Import models to ensure they're registered
logger.info("Importing and registering database models...")
from app.models.training import (
ModelTrainingLog,
TrainedModel,
ModelPerformanceMetric,
TrainingJobQueue,
ModelArtifact
)
# Verify models are registered in metadata
expected_tables = {
'model_training_logs', 'trained_models', 'model_performance_metrics',
'training_job_queue', 'model_artifacts'
}
registered_tables = set(Base.metadata.tables.keys())
missing_tables = expected_tables - registered_tables
if missing_tables:
raise Exception(f"Models not properly registered: {missing_tables}")
logger.info("Models registered successfully",
tables=list(registered_tables))
# Step 3: Create tables using shared infrastructure with verification
logger.info("Creating database tables...")
await database_manager.create_tables()
# Step 4: Verify tables were actually created
logger.info("Verifying table creation...")
verification_successful = await _verify_tables_exist()
if not verification_successful:
raise Exception("Table verification failed - tables were not created properly")
logger.info("Training service database initialized and verified successfully",
attempt=attempt)
return
except Exception as e:
logger.error("Database initialization failed",
attempt=attempt,
max_retries=max_retries,
error=str(e))
if attempt == max_retries:
logger.error("All database initialization attempts failed - giving up")
raise Exception(f"Failed to initialize training database after {max_retries} attempts: {str(e)}")
# Wait before retry with exponential backoff
wait_time = retry_delay * (2 ** (attempt - 1))
logger.info("Retrying database initialization",
retry_in_seconds=wait_time,
next_attempt=attempt + 1)
await asyncio.sleep(wait_time)
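The retry schedule above works out to waits of 2s, 4s, 8s and 16s between the five attempts. The same pattern, factored into a generic helper purely for illustration (this helper is not part of the commit):

```python
import asyncio


async def retry_with_backoff(operation, max_retries: int = 5, retry_delay: float = 2.0):
    """Run an async callable, retrying with exponential backoff (2s, 4s, 8s, 16s) like the initializer above."""
    for attempt in range(1, max_retries + 1):
        try:
            return await operation()
        except Exception:
            if attempt == max_retries:
                raise
            await asyncio.sleep(retry_delay * (2 ** (attempt - 1)))
```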
async def _verify_tables_exist() -> bool:
    """Verify that all required tables exist in the database"""
    try:
        async with database_manager.get_session() as session:
            # Check each required table exists and is accessible
            required_tables = [
                'model_training_logs',
                'trained_models',
                'model_performance_metrics',
                'training_job_queue',
                'model_artifacts'
            ]

            for table_name in required_tables:
                try:
                    # Try to query the table structure
                    result = await session.execute(
                        text(f"SELECT 1 FROM {table_name} LIMIT 1")
                    )
                    logger.debug(f"Table {table_name} exists and is accessible")
                except Exception as table_error:
                    # If it's a "relation does not exist" error, table creation failed
                    if "does not exist" in str(table_error).lower():
                        logger.error(f"Table {table_name} does not exist", error=str(table_error))
                        return False
                    # If it's an empty table, that's fine - table exists
                    elif "no data" in str(table_error).lower():
                        logger.debug(f"Table {table_name} exists but is empty (normal)")
                    else:
                        logger.warning(f"Unexpected error querying {table_name}", error=str(table_error))

            logger.info("All required tables verified successfully",
                        tables=required_tables)
            return True

    except Exception as e:
        logger.error("Table verification failed", error=str(e))
        return False
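An alternative way to perform the same existence check is SQLAlchemy's runtime inspector, which avoids issuing a `SELECT` against each table. This sketch assumes access to the underlying async engine (for example `database_manager.async_engine`) and is not part of the change:

```python
from sqlalchemy import inspect
from sqlalchemy.ext.asyncio import AsyncEngine

REQUIRED_TABLES = {
    'model_training_logs', 'trained_models', 'model_performance_metrics',
    'training_job_queue', 'model_artifacts'
}


async def verify_tables_via_inspector(engine: AsyncEngine) -> bool:
    """Compare the inspector's table list against the required set instead of probing with SELECTs."""
    async with engine.connect() as conn:
        existing = await conn.run_sync(lambda sync_conn: set(inspect(sync_conn).get_table_names()))
    return REQUIRED_TABLES.issubset(existing)
```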
# Database cleanup for training service
async def cleanup_training_database():

View File

@@ -16,7 +16,7 @@ from fastapi.responses import JSONResponse
import uvicorn
from app.core.config import settings
from app.core.database import initialize_training_database, cleanup_training_database, get_db_health, get_comprehensive_db_health
from app.api import training, models
from app.api.websocket import websocket_router
@@ -195,18 +195,69 @@ async def health_check():
@app.get("/health/ready") @app.get("/health/ready")
async def readiness_check(): async def readiness_check():
"""Kubernetes readiness probe endpoint""" """Kubernetes readiness probe endpoint with comprehensive database checks"""
checks = { try:
"database": await get_db_health(), # Get comprehensive database health including table verification
"application": getattr(app.state, 'ready', False) db_health = await get_comprehensive_db_health()
}
checks = {
if all(checks.values()): "database_connectivity": db_health["connectivity"],
return {"status": "ready", "checks": checks} "database_tables": db_health["tables_exist"],
else: "application": getattr(app.state, 'ready', False)
}
# Include detailed database info for debugging
database_details = {
"status": db_health["status"],
"tables_verified": db_health["tables_verified"],
"missing_tables": db_health["missing_tables"],
"errors": db_health["errors"]
}
# Service is ready only if all checks pass
all_ready = all(checks.values()) and db_health["status"] == "healthy"
if all_ready:
return {
"status": "ready",
"checks": checks,
"database": database_details
}
else:
return JSONResponse(
status_code=503,
content={
"status": "not ready",
"checks": checks,
"database": database_details
}
)
except Exception as e:
logger.error("Readiness check failed", error=str(e))
return JSONResponse( return JSONResponse(
status_code=503, status_code=503,
content={"status": "not ready", "checks": checks} content={
"status": "not ready",
"error": f"Health check failed: {str(e)}"
}
)
@app.get("/health/database")
async def database_health_check():
"""Detailed database health endpoint for debugging"""
try:
db_health = await get_comprehensive_db_health()
status_code = 200 if db_health["status"] == "healthy" else 503
return JSONResponse(status_code=status_code, content=db_health)
except Exception as e:
logger.error("Database health check failed", error=str(e))
return JSONResponse(
status_code=503,
content={
"status": "unhealthy",
"error": f"Health check failed: {str(e)}"
}
) )
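To exercise these endpoints outside the cluster, for example before wiring up the Kubernetes probes, a small client-side check might look like the following. `httpx` and the base URL are assumptions for the sketch, not part of this commit:

```python
import asyncio

import httpx


async def check_training_service_ready(base_url: str = "http://localhost:8000") -> bool:
    """Hit /health/ready the way a readiness probe would and dump the detailed payload on failure."""
    async with httpx.AsyncClient(timeout=5.0) as client:
        resp = await client.get(f"{base_url}/health/ready")
        if resp.status_code == 200:
            return True
        # 503 responses carry the failing checks and database details for debugging
        print(resp.json())
        return False


if __name__ == "__main__":
    print(asyncio.run(check_training_service_ready()))
```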
@app.get("/metrics") @app.get("/metrics")
@@ -220,11 +271,6 @@ async def get_metrics():
async def liveness_check(): async def liveness_check():
return {"status": "alive"} return {"status": "alive"}
@app.get("/health/ready")
async def readiness_check():
ready = getattr(app.state, 'ready', True)
return {"status": "ready" if ready else "not ready"}
@app.get("/") @app.get("/")
async def root(): async def root():
return {"service": "training-service", "version": "1.0.0"} return {"service": "training-service", "version": "1.0.0"}

View File

@@ -150,12 +150,33 @@ class DatabaseManager:
    # ===== TABLE MANAGEMENT =====

    async def create_tables(self, metadata=None):
        """Create database tables with enhanced error handling and transaction verification"""
        try:
            target_metadata = metadata or Base.metadata
            table_names = list(target_metadata.tables.keys())

            logger.info(f"Creating tables: {table_names}", service=self.service_name)

            # Use explicit transaction with proper error handling
            async with self.async_engine.begin() as conn:
                try:
                    # Create tables within the transaction
                    await conn.run_sync(target_metadata.create_all, checkfirst=True)

                    # Verify transaction is not in error state
                    # Try a simple query to ensure connection is still valid
                    await conn.execute(text("SELECT 1"))

                    logger.info("Database tables creation transaction completed successfully",
                                service=self.service_name, tables=table_names)
                except Exception as create_error:
                    logger.error(f"Error during table creation within transaction: {create_error}",
                                 service=self.service_name)
                    # Re-raise to trigger transaction rollback
                    raise

            logger.info("Database tables created successfully", service=self.service_name)
        except Exception as e:
            # Check if it's a "relation already exists" error which can be safely ignored
            error_str = str(e).lower()
@@ -164,6 +185,14 @@ class DatabaseManager:
logger.info("Database tables creation completed (some already existed)", service=self.service_name) logger.info("Database tables creation completed (some already existed)", service=self.service_name)
else: else:
logger.error(f"Failed to create tables: {e}", service=self.service_name) logger.error(f"Failed to create tables: {e}", service=self.service_name)
# Check for specific transaction error indicators
if any(indicator in error_str for indicator in [
"transaction", "rollback", "aborted", "failed sql transaction"
]):
logger.error("Transaction-related error detected during table creation",
service=self.service_name)
raise DatabaseError(f"Table creation failed: {str(e)}") raise DatabaseError(f"Table creation failed: {str(e)}")
    async def drop_tables(self, metadata=None):