Initial commit - production deployment
This commit is contained in:
707
frontend/src/api/hooks/training.ts
Normal file
707
frontend/src/api/hooks/training.ts
Normal file
@@ -0,0 +1,707 @@
|
||||
/**
|
||||
* Training React Query hooks
|
||||
* Provides data fetching, caching, and state management for training operations
|
||||
*/
|
||||
|
||||
import * as React from 'react';
|
||||
import { useMutation, useQuery, useQueryClient, UseQueryOptions, UseMutationOptions } from '@tanstack/react-query';
|
||||
import { trainingService } from '../services/training';
|
||||
import { ApiError, apiClient } from '../client/apiClient';
|
||||
import { useAuthStore } from '../../stores/auth.store';
|
||||
import {
|
||||
HTTP_POLLING_INTERVAL_MS,
|
||||
HTTP_POLLING_DEBOUNCE_MS,
|
||||
WEBSOCKET_HEARTBEAT_INTERVAL_MS,
|
||||
WEBSOCKET_MAX_RECONNECT_ATTEMPTS,
|
||||
WEBSOCKET_RECONNECT_INITIAL_DELAY_MS,
|
||||
WEBSOCKET_RECONNECT_MAX_DELAY_MS,
|
||||
PROGRESS_DATA_ANALYSIS,
|
||||
PROGRESS_TRAINING_RANGE_START,
|
||||
PROGRESS_TRAINING_RANGE_END
|
||||
} from '../../constants/training';
|
||||
import type {
|
||||
TrainingJobRequest,
|
||||
TrainingJobResponse,
|
||||
TrainingJobStatus,
|
||||
SingleProductTrainingRequest,
|
||||
ModelMetricsResponse,
|
||||
TrainedModelResponse,
|
||||
} from '../types/training';
|
||||
|
||||
// Query Keys Factory
|
||||
export const trainingKeys = {
|
||||
all: ['training'] as const,
|
||||
jobs: {
|
||||
all: () => [...trainingKeys.all, 'jobs'] as const,
|
||||
status: (tenantId: string, jobId: string) =>
|
||||
[...trainingKeys.jobs.all(), 'status', tenantId, jobId] as const,
|
||||
},
|
||||
models: {
|
||||
all: () => [...trainingKeys.all, 'models'] as const,
|
||||
lists: () => [...trainingKeys.models.all(), 'list'] as const,
|
||||
list: (tenantId: string, params?: any) =>
|
||||
[...trainingKeys.models.lists(), tenantId, params] as const,
|
||||
details: () => [...trainingKeys.models.all(), 'detail'] as const,
|
||||
detail: (tenantId: string, modelId: string) =>
|
||||
[...trainingKeys.models.details(), tenantId, modelId] as const,
|
||||
active: (tenantId: string, inventoryProductId: string) =>
|
||||
[...trainingKeys.models.all(), 'active', tenantId, inventoryProductId] as const,
|
||||
metrics: (tenantId: string, modelId: string) =>
|
||||
[...trainingKeys.models.all(), 'metrics', tenantId, modelId] as const,
|
||||
performance: (tenantId: string, modelId: string) =>
|
||||
[...trainingKeys.models.all(), 'performance', tenantId, modelId] as const,
|
||||
},
|
||||
statistics: (tenantId: string) =>
|
||||
[...trainingKeys.all, 'statistics', tenantId] as const,
|
||||
} as const;
|
||||
|
||||
// Training Job Queries
|
||||
export const useTrainingJobStatus = (
|
||||
tenantId: string,
|
||||
jobId: string,
|
||||
options?: Omit<UseQueryOptions<TrainingJobStatus, ApiError>, 'queryKey' | 'queryFn'> & {
|
||||
isWebSocketConnected?: boolean;
|
||||
}
|
||||
) => {
|
||||
const { isWebSocketConnected, ...queryOptions } = options || {};
|
||||
const [enablePolling, setEnablePolling] = React.useState(false);
|
||||
|
||||
// Debounce HTTP polling activation: wait after WebSocket disconnects
|
||||
// This prevents race conditions where both WebSocket and HTTP are briefly active
|
||||
React.useEffect(() => {
|
||||
if (!isWebSocketConnected) {
|
||||
const debounceTimer = setTimeout(() => {
|
||||
setEnablePolling(true);
|
||||
console.log(`🔄 HTTP polling enabled after ${HTTP_POLLING_DEBOUNCE_MS}ms debounce (WebSocket disconnected)`);
|
||||
}, HTTP_POLLING_DEBOUNCE_MS);
|
||||
|
||||
return () => clearTimeout(debounceTimer);
|
||||
} else {
|
||||
setEnablePolling(false);
|
||||
console.log('❌ HTTP polling disabled (WebSocket connected)');
|
||||
}
|
||||
}, [isWebSocketConnected]);
|
||||
|
||||
// Completely disable the query when WebSocket is connected or during debounce period
|
||||
const isEnabled = !!tenantId && !!jobId && !isWebSocketConnected && enablePolling;
|
||||
|
||||
console.log('🔄 Training status query:', {
|
||||
tenantId: !!tenantId,
|
||||
jobId: !!jobId,
|
||||
isWebSocketConnected,
|
||||
enablePolling,
|
||||
queryEnabled: isEnabled
|
||||
});
|
||||
|
||||
return useQuery<TrainingJobStatus, ApiError>({
|
||||
queryKey: trainingKeys.jobs.status(tenantId, jobId),
|
||||
queryFn: () => {
|
||||
console.log('📡 Executing HTTP training status query (WebSocket disconnected)');
|
||||
return trainingService.getTrainingJobStatus(tenantId, jobId);
|
||||
},
|
||||
enabled: isEnabled, // Completely disable when WebSocket connected
|
||||
refetchInterval: isEnabled ? (query) => {
|
||||
// Only set up refetch interval if the query is enabled
|
||||
const data = query.state.data;
|
||||
|
||||
// Stop polling if we get auth errors or training is completed
|
||||
if (query.state.error && (query.state.error as any)?.status === 401) {
|
||||
console.log('🚫 Stopping status polling due to auth error');
|
||||
return false;
|
||||
}
|
||||
if (data?.status === 'completed' || data?.status === 'failed') {
|
||||
console.log('🏁 Training completed - stopping HTTP polling');
|
||||
return false; // Stop polling when training is done
|
||||
}
|
||||
|
||||
console.log(`📊 HTTP fallback polling active (WebSocket disconnected) - ${HTTP_POLLING_INTERVAL_MS}ms interval`);
|
||||
return HTTP_POLLING_INTERVAL_MS; // Poll while training (fallback when WebSocket unavailable)
|
||||
} : false, // Completely disable interval when WebSocket connected
|
||||
staleTime: 1000, // Consider data stale after 1 second
|
||||
retry: (failureCount, error) => {
|
||||
// Don't retry on auth errors
|
||||
if ((error as any)?.status === 401) {
|
||||
console.log('🚫 Not retrying due to auth error');
|
||||
return false;
|
||||
}
|
||||
return failureCount < 3;
|
||||
},
|
||||
...queryOptions,
|
||||
});
|
||||
};
|
||||
|
||||
// Model Queries
|
||||
export const useActiveModel = (
|
||||
tenantId: string,
|
||||
inventoryProductId: string,
|
||||
options?: Omit<UseQueryOptions<any, ApiError>, 'queryKey' | 'queryFn'>
|
||||
) => {
|
||||
return useQuery<any, ApiError>({
|
||||
queryKey: trainingKeys.models.active(tenantId, inventoryProductId),
|
||||
queryFn: () => trainingService.getActiveModel(tenantId, inventoryProductId),
|
||||
enabled: !!tenantId && !!inventoryProductId,
|
||||
staleTime: 5 * 60 * 1000, // 5 minutes
|
||||
...options,
|
||||
});
|
||||
};
|
||||
|
||||
export const useModels = (
|
||||
tenantId: string,
|
||||
queryParams?: any,
|
||||
options?: Omit<UseQueryOptions<any, ApiError>, 'queryKey' | 'queryFn'>
|
||||
) => {
|
||||
return useQuery<any, ApiError>({
|
||||
queryKey: trainingKeys.models.list(tenantId, queryParams),
|
||||
queryFn: () => trainingService.getModels(tenantId, queryParams),
|
||||
enabled: !!tenantId,
|
||||
staleTime: 2 * 60 * 1000, // 2 minutes
|
||||
...options,
|
||||
});
|
||||
};
|
||||
|
||||
export const useModelMetrics = (
|
||||
tenantId: string,
|
||||
modelId: string,
|
||||
options?: Omit<UseQueryOptions<ModelMetricsResponse, ApiError>, 'queryKey' | 'queryFn'>
|
||||
) => {
|
||||
return useQuery<ModelMetricsResponse, ApiError>({
|
||||
queryKey: trainingKeys.models.metrics(tenantId, modelId),
|
||||
queryFn: () => trainingService.getModelMetrics(tenantId, modelId),
|
||||
enabled: !!tenantId && !!modelId,
|
||||
staleTime: 10 * 60 * 1000, // 10 minutes
|
||||
...options,
|
||||
});
|
||||
};
|
||||
|
||||
export const useModelPerformance = (
|
||||
tenantId: string,
|
||||
modelId: string,
|
||||
options?: Omit<UseQueryOptions<any, ApiError>, 'queryKey' | 'queryFn'>
|
||||
) => {
|
||||
return useQuery<any, ApiError>({
|
||||
queryKey: trainingKeys.models.performance(tenantId, modelId),
|
||||
queryFn: () => trainingService.getModelPerformance(tenantId, modelId),
|
||||
enabled: !!tenantId && !!modelId,
|
||||
staleTime: 10 * 60 * 1000, // 10 minutes
|
||||
...options,
|
||||
});
|
||||
};
|
||||
|
||||
// Statistics Queries
|
||||
export const useTenantTrainingStatistics = (
|
||||
tenantId: string,
|
||||
options?: Omit<UseQueryOptions<any, ApiError>, 'queryKey' | 'queryFn'>
|
||||
) => {
|
||||
return useQuery<any, ApiError>({
|
||||
queryKey: trainingKeys.statistics(tenantId),
|
||||
queryFn: () => trainingService.getTenantStatistics(tenantId),
|
||||
enabled: !!tenantId,
|
||||
staleTime: 5 * 60 * 1000, // 5 minutes
|
||||
...options,
|
||||
});
|
||||
};
|
||||
|
||||
// Training Job Mutations
|
||||
export const useCreateTrainingJob = (
|
||||
options?: UseMutationOptions<
|
||||
TrainingJobResponse,
|
||||
ApiError,
|
||||
{ tenantId: string; request: TrainingJobRequest }
|
||||
>
|
||||
) => {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
return useMutation<
|
||||
TrainingJobResponse,
|
||||
ApiError,
|
||||
{ tenantId: string; request: TrainingJobRequest }
|
||||
>({
|
||||
mutationFn: ({ tenantId, request }) => trainingService.createTrainingJob(tenantId, request),
|
||||
onSuccess: (data, { tenantId }) => {
|
||||
// Add the job status to cache
|
||||
queryClient.setQueryData(
|
||||
trainingKeys.jobs.status(tenantId, data.job_id),
|
||||
{
|
||||
job_id: data.job_id,
|
||||
status: data.status,
|
||||
progress: 0,
|
||||
}
|
||||
);
|
||||
|
||||
// Invalidate statistics to reflect the new training job
|
||||
queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });
|
||||
},
|
||||
...options,
|
||||
});
|
||||
};
|
||||
|
||||
export const useTrainSingleProduct = (
|
||||
options?: UseMutationOptions<
|
||||
TrainingJobResponse,
|
||||
ApiError,
|
||||
{ tenantId: string; inventoryProductId: string; request: SingleProductTrainingRequest }
|
||||
>
|
||||
) => {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
return useMutation<
|
||||
TrainingJobResponse,
|
||||
ApiError,
|
||||
{ tenantId: string; inventoryProductId: string; request: SingleProductTrainingRequest }
|
||||
>({
|
||||
mutationFn: ({ tenantId, inventoryProductId, request }) =>
|
||||
trainingService.trainSingleProduct(tenantId, inventoryProductId, request),
|
||||
onSuccess: (data, { tenantId, inventoryProductId }) => {
|
||||
// Add the job status to cache
|
||||
queryClient.setQueryData(
|
||||
trainingKeys.jobs.status(tenantId, data.job_id),
|
||||
{
|
||||
job_id: data.job_id,
|
||||
status: data.status,
|
||||
progress: 0,
|
||||
}
|
||||
);
|
||||
|
||||
// Invalidate active model for this product
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: trainingKeys.models.active(tenantId, inventoryProductId)
|
||||
});
|
||||
|
||||
// Invalidate statistics
|
||||
queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });
|
||||
},
|
||||
...options,
|
||||
});
|
||||
};
|
||||
|
||||
// Admin Mutations
|
||||
export const useDeleteAllTenantModels = (
|
||||
options?: UseMutationOptions<{ message: string }, ApiError, { tenantId: string }>
|
||||
) => {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
return useMutation<{ message: string }, ApiError, { tenantId: string }>({
|
||||
mutationFn: ({ tenantId }) => trainingService.deleteAllTenantModels(tenantId),
|
||||
onSuccess: (_, { tenantId }) => {
|
||||
// Invalidate all model-related queries for this tenant
|
||||
queryClient.invalidateQueries({ queryKey: trainingKeys.models.all() });
|
||||
queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });
|
||||
},
|
||||
...options,
|
||||
});
|
||||
};
|
||||
|
||||
// WebSocket Hook for Real-time Training Updates
|
||||
export const useTrainingWebSocket = (
|
||||
tenantId: string,
|
||||
jobId: string,
|
||||
token?: string,
|
||||
options?: {
|
||||
onProgress?: (data: any) => void;
|
||||
onCompleted?: (data: any) => void;
|
||||
onError?: (error: any) => void;
|
||||
onStarted?: (data: any) => void;
|
||||
onCancelled?: (data: any) => void;
|
||||
}
|
||||
) => {
|
||||
const queryClient = useQueryClient();
|
||||
const [isConnected, setIsConnected] = React.useState(false);
|
||||
const [connectionError, setConnectionError] = React.useState<string | null>(null);
|
||||
const [connectionAttempts, setConnectionAttempts] = React.useState(0);
|
||||
|
||||
// Memoize options to prevent unnecessary effect re-runs
|
||||
const memoizedOptions = React.useMemo(() => options, [
|
||||
options?.onProgress,
|
||||
options?.onCompleted,
|
||||
options?.onError,
|
||||
options?.onStarted,
|
||||
options?.onCancelled
|
||||
]);
|
||||
|
||||
React.useEffect(() => {
|
||||
if (!tenantId || !jobId || !memoizedOptions) {
|
||||
return;
|
||||
}
|
||||
|
||||
let ws: WebSocket | null = null;
|
||||
let reconnectTimer: NodeJS.Timeout | null = null;
|
||||
let isManuallyDisconnected = false;
|
||||
let reconnectAttempts = 0;
|
||||
const maxReconnectAttempts = WEBSOCKET_MAX_RECONNECT_ATTEMPTS;
|
||||
|
||||
const connect = async () => {
|
||||
try {
|
||||
setConnectionError(null);
|
||||
setConnectionAttempts(prev => prev + 1);
|
||||
|
||||
// Use centralized token management from apiClient
|
||||
let effectiveToken: string | null;
|
||||
|
||||
try {
|
||||
// Always use the apiClient's token management
|
||||
effectiveToken = await apiClient.ensureValidToken();
|
||||
|
||||
if (!effectiveToken) {
|
||||
throw new Error('No valid token available');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('❌ Failed to get valid token for WebSocket:', error);
|
||||
setConnectionError('Authentication failed. Please log in again.');
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`🔄 Attempting WebSocket connection (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts + 1}):`, {
|
||||
tenantId,
|
||||
jobId,
|
||||
hasToken: !!effectiveToken,
|
||||
tokenFromApiClient: true
|
||||
});
|
||||
|
||||
ws = trainingService.createWebSocketConnection(tenantId, jobId, effectiveToken);
|
||||
|
||||
ws.onopen = () => {
|
||||
console.log('✅ Training WebSocket connected successfully', {
|
||||
readyState: ws?.readyState,
|
||||
url: ws?.url,
|
||||
jobId
|
||||
});
|
||||
// Track connection time for debugging
|
||||
(ws as any)._connectTime = Date.now();
|
||||
setIsConnected(true);
|
||||
reconnectAttempts = 0; // Reset on successful connection
|
||||
|
||||
// Request current status on connection
|
||||
try {
|
||||
ws?.send('get_status');
|
||||
console.log('📤 Requested current training status');
|
||||
} catch (e) {
|
||||
console.warn('Failed to request status on connection:', e);
|
||||
}
|
||||
|
||||
// Helper function to check if tokens represent different auth users/sessions
|
||||
const isNewAuthSession = (oldToken: string, newToken: string): boolean => {
|
||||
if (!oldToken || !newToken) return !!oldToken !== !!newToken;
|
||||
|
||||
try {
|
||||
const oldPayload = JSON.parse(atob(oldToken.split('.')[1]));
|
||||
const newPayload = JSON.parse(atob(newToken.split('.')[1]));
|
||||
|
||||
// Compare by user ID - different user means new auth session
|
||||
// If user_id is same, it's just a token refresh, no need to reconnect
|
||||
return oldPayload.user_id !== newPayload.user_id ||
|
||||
oldPayload.sub !== newPayload.sub;
|
||||
} catch (e) {
|
||||
console.warn('Failed to parse token for session comparison:', e);
|
||||
// On parse error, don't reconnect (assume same session)
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
// Set up periodic ping and check for auth session changes
|
||||
const heartbeatInterval = setInterval(async () => {
|
||||
if (ws?.readyState === WebSocket.OPEN && !isManuallyDisconnected) {
|
||||
try {
|
||||
// Check token validity (this may refresh if needed)
|
||||
const currentToken = await apiClient.ensureValidToken();
|
||||
|
||||
// Only reconnect if user changed (new auth session)
|
||||
if (currentToken && effectiveToken && isNewAuthSession(effectiveToken, currentToken)) {
|
||||
console.log('🔄 Auth session changed (different user) - reconnecting WebSocket');
|
||||
ws?.close(1000, 'Auth session changed - reconnecting');
|
||||
clearInterval(heartbeatInterval);
|
||||
return;
|
||||
}
|
||||
|
||||
// Token may have been refreshed but it's the same user - continue
|
||||
if (currentToken && currentToken !== effectiveToken) {
|
||||
console.log('ℹ️ Token refreshed (same user) - updating reference');
|
||||
effectiveToken = currentToken;
|
||||
}
|
||||
|
||||
// Send ping
|
||||
ws?.send('ping');
|
||||
console.log('💓 Sent ping to server');
|
||||
} catch (e) {
|
||||
console.warn('Failed to send ping or validate token:', e);
|
||||
clearInterval(heartbeatInterval);
|
||||
}
|
||||
} else {
|
||||
clearInterval(heartbeatInterval);
|
||||
}
|
||||
}, WEBSOCKET_HEARTBEAT_INTERVAL_MS); // Check for auth changes and send ping
|
||||
|
||||
// Store interval for cleanup
|
||||
(ws as any).heartbeatInterval = heartbeatInterval;
|
||||
};
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
try {
|
||||
// Handle non-JSON messages (like pong responses)
|
||||
if (typeof event.data === 'string' && event.data === 'pong') {
|
||||
console.log('🏓 Pong received from server');
|
||||
return;
|
||||
}
|
||||
|
||||
const message = JSON.parse(event.data);
|
||||
|
||||
console.log('🔔 Training WebSocket message received:', message);
|
||||
|
||||
// Handle initial state message to restore the latest known state
|
||||
if (message.type === 'initial_state') {
|
||||
console.log('📥 Received initial state:', message.data);
|
||||
const initialData = message.data;
|
||||
const initialEventData = initialData.data || {};
|
||||
let initialProgress = initialEventData.progress || 0;
|
||||
|
||||
// Calculate progress for product_completed events
|
||||
if (initialData.type === 'product_completed') {
|
||||
const productsCompleted = initialEventData.products_completed || 0;
|
||||
const totalProducts = initialEventData.total_products || 1;
|
||||
const trainingRangeWidth = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS;
|
||||
initialProgress = PROGRESS_DATA_ANALYSIS + Math.floor((productsCompleted / totalProducts) * trainingRangeWidth);
|
||||
console.log('📦 Product training completed in initial state',
|
||||
`${productsCompleted}/${totalProducts}`,
|
||||
`progress: ${initialProgress}%`);
|
||||
}
|
||||
|
||||
// Update job status in cache with initial state
|
||||
queryClient.setQueryData(
|
||||
trainingKeys.jobs.status(tenantId, jobId),
|
||||
(oldData: TrainingJobStatus | undefined) => ({
|
||||
...oldData,
|
||||
job_id: jobId,
|
||||
status: initialData.type === 'completed' ? 'completed' :
|
||||
initialData.type === 'failed' ? 'failed' :
|
||||
initialData.type === 'started' ? 'running' :
|
||||
initialData.type === 'progress' ? 'running' :
|
||||
initialData.type === 'product_completed' ? 'running' :
|
||||
initialData.type === 'step_completed' ? 'running' :
|
||||
oldData?.status || 'running',
|
||||
progress: typeof initialProgress === 'number' ? initialProgress : oldData?.progress || 0,
|
||||
current_step: initialEventData.current_step || initialEventData.step_name || oldData?.current_step,
|
||||
})
|
||||
);
|
||||
return; // Initial state messages are only for state restoration, don't process as regular events
|
||||
}
|
||||
|
||||
// Extract data from backend message structure
|
||||
const eventData = message.data || {};
|
||||
let progress = eventData.progress || 0;
|
||||
const currentStep = eventData.current_step || eventData.step_name || '';
|
||||
const stepDetails = eventData.step_details || '';
|
||||
|
||||
// Handle product_completed events - calculate progress dynamically
|
||||
if (message.type === 'product_completed') {
|
||||
const productsCompleted = eventData.products_completed || 0;
|
||||
const totalProducts = eventData.total_products || 1;
|
||||
|
||||
// Calculate progress: DATA_ANALYSIS% base + (completed/total * (TRAINING_RANGE_END - DATA_ANALYSIS)%)
|
||||
const trainingRangeWidth = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS;
|
||||
progress = PROGRESS_DATA_ANALYSIS + Math.floor((productsCompleted / totalProducts) * trainingRangeWidth);
|
||||
|
||||
console.log('📦 Product training completed',
|
||||
`${productsCompleted}/${totalProducts}`,
|
||||
`progress: ${progress}%`);
|
||||
}
|
||||
|
||||
// Update job status in cache
|
||||
queryClient.setQueryData(
|
||||
trainingKeys.jobs.status(tenantId, jobId),
|
||||
(oldData: TrainingJobStatus | undefined) => ({
|
||||
...oldData,
|
||||
job_id: jobId,
|
||||
status: message.type === 'completed' ? 'completed' :
|
||||
message.type === 'failed' ? 'failed' :
|
||||
message.type === 'started' ? 'running' :
|
||||
oldData?.status || 'running',
|
||||
progress: typeof progress === 'number' ? progress : oldData?.progress || 0,
|
||||
current_step: currentStep || oldData?.current_step,
|
||||
})
|
||||
);
|
||||
|
||||
// Call appropriate callback based on message type
|
||||
switch (message.type) {
|
||||
case 'connected':
|
||||
console.log('🔗 WebSocket connected');
|
||||
break;
|
||||
|
||||
case 'started':
|
||||
console.log('🚀 Training started');
|
||||
memoizedOptions?.onStarted?.(message);
|
||||
break;
|
||||
|
||||
case 'progress':
|
||||
console.log('📊 Training progress update', `${progress}%`);
|
||||
memoizedOptions?.onProgress?.(message);
|
||||
break;
|
||||
|
||||
case 'product_completed':
|
||||
console.log('✅ Product training completed');
|
||||
// Treat as progress update
|
||||
memoizedOptions?.onProgress?.({
|
||||
...message,
|
||||
data: {
|
||||
...eventData,
|
||||
progress, // Use calculated progress
|
||||
}
|
||||
});
|
||||
break;
|
||||
|
||||
case 'step_completed':
|
||||
console.log('📋 Step completed');
|
||||
memoizedOptions?.onProgress?.(message);
|
||||
break;
|
||||
|
||||
case 'completed':
|
||||
console.log('✅ Training completed successfully');
|
||||
memoizedOptions?.onCompleted?.(message);
|
||||
// Invalidate models and statistics
|
||||
queryClient.invalidateQueries({ queryKey: trainingKeys.models.all() });
|
||||
queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });
|
||||
isManuallyDisconnected = true;
|
||||
break;
|
||||
|
||||
case 'failed':
|
||||
console.log('❌ Training failed');
|
||||
memoizedOptions?.onError?.(message);
|
||||
isManuallyDisconnected = true;
|
||||
break;
|
||||
|
||||
default:
|
||||
console.log(`🔍 Unknown message type: ${message.type}`);
|
||||
break;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error parsing WebSocket message:', error);
|
||||
setConnectionError('Error parsing message from server');
|
||||
}
|
||||
};
|
||||
|
||||
ws.onerror = (error) => {
|
||||
console.error('Training WebSocket error:', error);
|
||||
setConnectionError('WebSocket connection error');
|
||||
setIsConnected(false);
|
||||
};
|
||||
|
||||
ws.onclose = (event) => {
|
||||
console.log(`❌ Training WebSocket disconnected. Code: ${event.code}, Reason: "${event.reason}"`, {
|
||||
wasClean: event.wasClean,
|
||||
jobId,
|
||||
timeConnected: ws ? `${Date.now() - (ws as any)._connectTime || 0}ms` : 'unknown',
|
||||
reconnectAttempts
|
||||
});
|
||||
setIsConnected(false);
|
||||
|
||||
// Detailed logging for different close codes
|
||||
switch (event.code) {
|
||||
case 1000:
|
||||
if (event.reason === 'Auth session changed - reconnecting') {
|
||||
console.log('🔄 WebSocket closed for auth session change - will reconnect immediately');
|
||||
} else {
|
||||
console.log('🔒 WebSocket closed normally');
|
||||
}
|
||||
break;
|
||||
case 1006:
|
||||
console.log('⚠️ WebSocket closed abnormally (1006) - likely server-side issue or network problem');
|
||||
break;
|
||||
case 1001:
|
||||
console.log('🔄 WebSocket endpoint going away');
|
||||
break;
|
||||
case 1003:
|
||||
console.log('❌ WebSocket unsupported data received');
|
||||
break;
|
||||
default:
|
||||
console.log(`❓ WebSocket closed with code ${event.code}`);
|
||||
}
|
||||
|
||||
// Handle auth session change reconnection (immediate reconnect)
|
||||
if (event.code === 1000 && event.reason === 'Auth session changed - reconnecting') {
|
||||
console.log('🔄 Reconnecting immediately due to auth session change...');
|
||||
reconnectTimer = setTimeout(() => {
|
||||
connect(); // Reconnect immediately with new session token
|
||||
}, WEBSOCKET_RECONNECT_INITIAL_DELAY_MS); // Short delay to allow cleanup
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to reconnect if not manually disconnected and haven't exceeded max attempts
|
||||
if (!isManuallyDisconnected && event.code !== 1000 && reconnectAttempts < maxReconnectAttempts) {
|
||||
const delay = Math.min(
|
||||
WEBSOCKET_RECONNECT_INITIAL_DELAY_MS * Math.pow(2, reconnectAttempts),
|
||||
WEBSOCKET_RECONNECT_MAX_DELAY_MS
|
||||
); // Exponential backoff
|
||||
console.log(`🔄 Attempting to reconnect WebSocket in ${delay/1000}s... (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})`);
|
||||
|
||||
reconnectTimer = setTimeout(() => {
|
||||
reconnectAttempts++;
|
||||
connect();
|
||||
}, delay);
|
||||
} else if (reconnectAttempts >= maxReconnectAttempts) {
|
||||
console.log(`❌ Max reconnection attempts (${maxReconnectAttempts}) reached. Giving up.`);
|
||||
setConnectionError(`Connection failed after ${maxReconnectAttempts} attempts. The training job may not exist or the server may be unavailable.`);
|
||||
}
|
||||
};
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error creating WebSocket connection:', error);
|
||||
setConnectionError('Failed to create WebSocket connection');
|
||||
}
|
||||
};
|
||||
|
||||
// Connect immediately to avoid missing early progress updates
|
||||
console.log('🚀 Starting immediate WebSocket connection...');
|
||||
connect();
|
||||
|
||||
// Cleanup function
|
||||
return () => {
|
||||
isManuallyDisconnected = true;
|
||||
|
||||
if (reconnectTimer) {
|
||||
clearTimeout(reconnectTimer);
|
||||
}
|
||||
|
||||
if (ws) {
|
||||
ws.close(1000, 'Component unmounted');
|
||||
}
|
||||
|
||||
setIsConnected(false);
|
||||
};
|
||||
}, [tenantId, jobId, queryClient, memoizedOptions]);
|
||||
|
||||
return {
|
||||
isConnected,
|
||||
connectionError
|
||||
};
|
||||
};
|
||||
|
||||
// Utility Hooks
|
||||
export const useIsTrainingInProgress = (
|
||||
tenantId: string,
|
||||
jobId?: string,
|
||||
isWebSocketConnected?: boolean
|
||||
) => {
|
||||
const { data: jobStatus } = useTrainingJobStatus(tenantId, jobId || '', {
|
||||
enabled: !!jobId,
|
||||
isWebSocketConnected,
|
||||
});
|
||||
|
||||
return jobStatus?.status === 'running' || jobStatus?.status === 'pending';
|
||||
};
|
||||
|
||||
export const useTrainingProgress = (
|
||||
tenantId: string,
|
||||
jobId?: string,
|
||||
isWebSocketConnected?: boolean
|
||||
) => {
|
||||
const { data: jobStatus } = useTrainingJobStatus(tenantId, jobId || '', {
|
||||
enabled: !!jobId,
|
||||
isWebSocketConnected,
|
||||
});
|
||||
|
||||
return {
|
||||
progress: jobStatus?.progress || 0,
|
||||
currentStep: jobStatus?.current_step,
|
||||
isComplete: jobStatus?.status === 'completed',
|
||||
isFailed: jobStatus?.status === 'failed',
|
||||
isRunning: jobStatus?.status === 'running',
|
||||
};
|
||||
};
|
||||
Reference in New Issue
Block a user