/**
 * Training React Query hooks
 *
 * Provides data fetching, caching, and state management for training operations.
 */
|
|||
|
|
|
|||
|
|
import * as React from 'react';
|
|||
|
|
import { useMutation, useQuery, useQueryClient, UseQueryOptions, UseMutationOptions } from '@tanstack/react-query';
|
|||
|
|
import { trainingService } from '../services/training';
|
|||
|
|
import { ApiError, apiClient } from '../client/apiClient';
|
|||
|
|
import { useAuthStore } from '../../stores/auth.store';
|
|||
|
|
import {
|
|||
|
|
HTTP_POLLING_INTERVAL_MS,
|
|||
|
|
HTTP_POLLING_DEBOUNCE_MS,
|
|||
|
|
WEBSOCKET_HEARTBEAT_INTERVAL_MS,
|
|||
|
|
WEBSOCKET_MAX_RECONNECT_ATTEMPTS,
|
|||
|
|
WEBSOCKET_RECONNECT_INITIAL_DELAY_MS,
|
|||
|
|
WEBSOCKET_RECONNECT_MAX_DELAY_MS,
|
|||
|
|
PROGRESS_DATA_ANALYSIS,
|
|||
|
|
PROGRESS_TRAINING_RANGE_START,
|
|||
|
|
PROGRESS_TRAINING_RANGE_END
|
|||
|
|
} from '../../constants/training';
|
|||
|
|
import type {
|
|||
|
|
TrainingJobRequest,
|
|||
|
|
TrainingJobResponse,
|
|||
|
|
TrainingJobStatus,
|
|||
|
|
SingleProductTrainingRequest,
|
|||
|
|
ModelMetricsResponse,
|
|||
|
|
TrainedModelResponse,
|
|||
|
|
} from '../types/training';
|
|||
|
|
|
|||
|
|
// Query Keys Factory

/**
 * React Query key factory for training data.
 *
 * Keys are built hierarchically from the `['training']` root, so a broader key
 * (for example `trainingKeys.models.all()`) is a prefix of every narrower key
 * beneath it and can be used to invalidate that whole subtree. Tenant and
 * entity ids are appended after the segment name, not at the root.
 */
export const trainingKeys = {
  all: ['training'] as const,

  jobs: {
    all() {
      return [...trainingKeys.all, 'jobs'] as const;
    },
    status(tenantId: string, jobId: string) {
      return [...trainingKeys.jobs.all(), 'status', tenantId, jobId] as const;
    },
  },

  models: {
    all() {
      return [...trainingKeys.all, 'models'] as const;
    },
    lists() {
      return [...trainingKeys.models.all(), 'list'] as const;
    },
    list(tenantId: string, params?: any) {
      return [...trainingKeys.models.lists(), tenantId, params] as const;
    },
    details() {
      return [...trainingKeys.models.all(), 'detail'] as const;
    },
    detail(tenantId: string, modelId: string) {
      return [...trainingKeys.models.details(), tenantId, modelId] as const;
    },
    active(tenantId: string, inventoryProductId: string) {
      return [...trainingKeys.models.all(), 'active', tenantId, inventoryProductId] as const;
    },
    metrics(tenantId: string, modelId: string) {
      return [...trainingKeys.models.all(), 'metrics', tenantId, modelId] as const;
    },
    performance(tenantId: string, modelId: string) {
      return [...trainingKeys.models.all(), 'performance', tenantId, modelId] as const;
    },
  },

  statistics(tenantId: string) {
    return [...trainingKeys.all, 'statistics', tenantId] as const;
  },
} as const;
|
|||
|
|
|
|||
|
|
// Training Job Queries

/**
 * True when an error carries HTTP status 401.
 * NOTE(review): assumes `ApiError` exposes the HTTP status as `status`, as the
 * original checks did — confirm against `apiClient`.
 */
const isUnauthorizedError = (error: unknown): boolean =>
  (error as { status?: unknown } | null | undefined)?.status === 401;

/**
 * Polls a training job's status over HTTP as a fallback for the WebSocket feed.
 *
 * The query is enabled only while `isWebSocketConnected` is falsy, and only
 * after `HTTP_POLLING_DEBOUNCE_MS` has elapsed since it became falsy, so HTTP
 * polling and the WebSocket are not both active. While enabled it refetches
 * every `HTTP_POLLING_INTERVAL_MS`. It stops refetching on a 401 or once the
 * job is `completed` / `failed`. 401s are not retried; other errors are retried
 * while `failureCount < 3`.
 *
 * @param tenantId - Tenant owning the job; the query is disabled while empty.
 * @param jobId - Job to poll; the query is disabled while empty.
 * @param options - Extra React Query options plus `isWebSocketConnected`.
 *   A caller-supplied `enabled` is combined (AND) with the internal gating:
 *   `enabled: false` disables the query, but it can no longer re-enable polling
 *   while the WebSocket is connected or during the debounce window. (A
 *   function-valued `enabled` is treated as truthy.) All other options still
 *   override the defaults set here.
 */
export const useTrainingJobStatus = (
  tenantId: string,
  jobId: string,
  options?: Omit<UseQueryOptions<TrainingJobStatus, ApiError>, 'queryKey' | 'queryFn'> & {
    isWebSocketConnected?: boolean;
  }
) => {
  // `enabled` is pulled out of the spread: spreading it last used to replace
  // the WebSocket/debounce gating entirely.
  const { isWebSocketConnected, enabled: callerEnabled, ...queryOptions } = options || {};
  const [enablePolling, setEnablePolling] = React.useState(false);

  // Debounce HTTP polling activation: wait after WebSocket disconnects.
  // This prevents race conditions where both WebSocket and HTTP are briefly active.
  React.useEffect(() => {
    if (isWebSocketConnected) {
      setEnablePolling(false);
      console.log('❌ HTTP polling disabled (WebSocket connected)');
      return undefined;
    }

    const debounceTimer = setTimeout(() => {
      setEnablePolling(true);
      console.log(`🔄 HTTP polling enabled after ${HTTP_POLLING_DEBOUNCE_MS}ms debounce (WebSocket disconnected)`);
    }, HTTP_POLLING_DEBOUNCE_MS);

    return () => clearTimeout(debounceTimer);
  }, [isWebSocketConnected]);

  // Completely disable the query when WebSocket is connected, during the
  // debounce period, or when the caller explicitly disabled it.
  const isEnabled =
    !!tenantId && !!jobId && !isWebSocketConnected && enablePolling && callerEnabled !== false;

  console.log('🔄 Training status query:', {
    tenantId: !!tenantId,
    jobId: !!jobId,
    isWebSocketConnected,
    enablePolling,
    queryEnabled: isEnabled
  });

  return useQuery<TrainingJobStatus, ApiError>({
    queryKey: trainingKeys.jobs.status(tenantId, jobId),
    queryFn: () => {
      console.log('📡 Executing HTTP training status query (WebSocket disconnected)');
      return trainingService.getTrainingJobStatus(tenantId, jobId);
    },
    enabled: isEnabled,
    refetchInterval: isEnabled ? (query) => {
      const data = query.state.data;

      // Stop polling if we get auth errors or training is completed
      if (query.state.error && isUnauthorizedError(query.state.error)) {
        console.log('🚫 Stopping status polling due to auth error');
        return false;
      }

      if (data?.status === 'completed' || data?.status === 'failed') {
        console.log('🏁 Training completed - stopping HTTP polling');
        return false; // Stop polling when training is done
      }

      console.log(`📊 HTTP fallback polling active (WebSocket disconnected) - ${HTTP_POLLING_INTERVAL_MS}ms interval`);
      return HTTP_POLLING_INTERVAL_MS; // Poll while training (fallback when WebSocket unavailable)
    } : false, // No interval while the query is gated off
    staleTime: 1000, // Consider data stale after 1 second
    retry: (failureCount, error) => {
      // Don't retry on auth errors
      if (isUnauthorizedError(error)) {
        console.log('🚫 Not retrying due to auth error');
        return false;
      }
      return failureCount < 3;
    },
    ...queryOptions,
  });
};
|
|||
|
|
|
|||
|
|
// Model Queries

/**
 * Fetches the active model for one inventory product via
 * `trainingService.getActiveModel`.
 *
 * Disabled until both `tenantId` and `inventoryProductId` are non-empty.
 * Results are considered fresh for 5 minutes. `options` override these
 * defaults.
 */
export const useActiveModel = (
  tenantId: string,
  inventoryProductId: string,
  options?: Omit<UseQueryOptions<any, ApiError>, 'queryKey' | 'queryFn'>
) => {
  const activeModelKey = trainingKeys.models.active(tenantId, inventoryProductId);
  const hasIdentifiers = Boolean(tenantId) && Boolean(inventoryProductId);

  return useQuery<any, ApiError>({
    queryKey: activeModelKey,
    queryFn: () => trainingService.getActiveModel(tenantId, inventoryProductId),
    enabled: hasIdentifiers,
    staleTime: 300_000, // 5 minutes
    ...options,
  });
};
|
|||
|
|
|
|||
|
|
/**
 * Lists a tenant's models via `trainingService.getModels`.
 *
 * `queryParams` is forwarded to the service and is also part of the cache key,
 * so each distinct parameter set is cached separately. Disabled while
 * `tenantId` is empty; results are considered fresh for 2 minutes.
 */
export const useModels = (
  tenantId: string,
  queryParams?: any,
  options?: Omit<UseQueryOptions<any, ApiError>, 'queryKey' | 'queryFn'>
) => {
  const listKey = trainingKeys.models.list(tenantId, queryParams);
  const hasTenant = Boolean(tenantId);

  return useQuery<any, ApiError>({
    queryKey: listKey,
    queryFn: () => trainingService.getModels(tenantId, queryParams),
    enabled: hasTenant,
    staleTime: 120_000, // 2 minutes
    ...options,
  });
};
|
|||
|
|
|
|||
|
|
/**
 * Fetches metrics for one model via `trainingService.getModelMetrics`.
 *
 * Disabled until both `tenantId` and `modelId` are non-empty. Results are
 * considered fresh for 10 minutes.
 */
export const useModelMetrics = (
  tenantId: string,
  modelId: string,
  options?: Omit<UseQueryOptions<ModelMetricsResponse, ApiError>, 'queryKey' | 'queryFn'>
) => {
  const metricsKey = trainingKeys.models.metrics(tenantId, modelId);
  const hasIdentifiers = Boolean(tenantId) && Boolean(modelId);

  return useQuery<ModelMetricsResponse, ApiError>({
    queryKey: metricsKey,
    queryFn: () => trainingService.getModelMetrics(tenantId, modelId),
    enabled: hasIdentifiers,
    staleTime: 600_000, // 10 minutes
    ...options,
  });
};
|
|||
|
|
|
|||
|
|
/**
 * Fetches performance data for one model via
 * `trainingService.getModelPerformance`.
 *
 * Disabled until both `tenantId` and `modelId` are non-empty. Results are
 * considered fresh for 10 minutes.
 */
export const useModelPerformance = (
  tenantId: string,
  modelId: string,
  options?: Omit<UseQueryOptions<any, ApiError>, 'queryKey' | 'queryFn'>
) => {
  const performanceKey = trainingKeys.models.performance(tenantId, modelId);
  const hasIdentifiers = Boolean(tenantId) && Boolean(modelId);

  return useQuery<any, ApiError>({
    queryKey: performanceKey,
    queryFn: () => trainingService.getModelPerformance(tenantId, modelId),
    enabled: hasIdentifiers,
    staleTime: 600_000, // 10 minutes
    ...options,
  });
};
|
|||
|
|
|
|||
|
|
// Statistics Queries

/**
 * Fetches a tenant's training statistics via
 * `trainingService.getTenantStatistics`.
 *
 * Disabled while `tenantId` is empty. Results are considered fresh for
 * 5 minutes.
 */
export const useTenantTrainingStatistics = (
  tenantId: string,
  options?: Omit<UseQueryOptions<any, ApiError>, 'queryKey' | 'queryFn'>
) => {
  const statisticsKey = trainingKeys.statistics(tenantId);
  const hasTenant = Boolean(tenantId);

  return useQuery<any, ApiError>({
    queryKey: statisticsKey,
    queryFn: () => trainingService.getTenantStatistics(tenantId),
    enabled: hasTenant,
    staleTime: 300_000, // 5 minutes
    ...options,
  });
};
|
|||
|
|
|
|||
|
|
// Training Job Mutations

/**
 * Starts a tenant-wide training job via `trainingService.createTrainingJob`.
 *
 * On success:
 * - the new job's status is seeded into the cache with progress 0;
 * - the tenant's training statistics are invalidated.
 *
 * A caller-supplied `onSuccess` runs after this bookkeeping (with the same
 * arguments) instead of replacing it. Its return value is passed through, so
 * a returned promise is still awaited by React Query.
 */
export const useCreateTrainingJob = (
  options?: UseMutationOptions<
    TrainingJobResponse,
    ApiError,
    { tenantId: string; request: TrainingJobRequest }
  >
) => {
  const queryClient = useQueryClient();
  // Pull `onSuccess` out so spreading the remaining options cannot overwrite
  // the cache bookkeeping below.
  const { onSuccess: callerOnSuccess, ...mutationOptions } = options || {};

  return useMutation<
    TrainingJobResponse,
    ApiError,
    { tenantId: string; request: TrainingJobRequest }
  >({
    mutationFn: ({ tenantId, request }) => trainingService.createTrainingJob(tenantId, request),
    ...mutationOptions,
    onSuccess: (...args) => {
      const [data, { tenantId }] = args;

      // Add the job status to cache
      queryClient.setQueryData(
        trainingKeys.jobs.status(tenantId, data.job_id),
        {
          job_id: data.job_id,
          status: data.status,
          progress: 0,
        }
      );

      // Invalidate statistics to reflect the new training job
      queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });

      return callerOnSuccess?.(...args);
    },
  });
};
|
|||
|
|
|
|||
|
|
/**
 * Starts training for a single inventory product via
 * `trainingService.trainSingleProduct`.
 *
 * On success:
 * - the new job's status is seeded into the cache with progress 0;
 * - the product's active-model query is invalidated;
 * - the tenant's statistics are invalidated.
 *
 * A caller-supplied `onSuccess` runs after this bookkeeping (with the same
 * arguments) instead of replacing it.
 */
export const useTrainSingleProduct = (
  options?: UseMutationOptions<
    TrainingJobResponse,
    ApiError,
    { tenantId: string; inventoryProductId: string; request: SingleProductTrainingRequest }
  >
) => {
  const queryClient = useQueryClient();
  // Pull `onSuccess` out so spreading the remaining options cannot overwrite
  // the cache bookkeeping below.
  const { onSuccess: callerOnSuccess, ...mutationOptions } = options || {};

  return useMutation<
    TrainingJobResponse,
    ApiError,
    { tenantId: string; inventoryProductId: string; request: SingleProductTrainingRequest }
  >({
    mutationFn: ({ tenantId, inventoryProductId, request }) =>
      trainingService.trainSingleProduct(tenantId, inventoryProductId, request),
    ...mutationOptions,
    onSuccess: (...args) => {
      const [data, { tenantId, inventoryProductId }] = args;

      // Add the job status to cache
      queryClient.setQueryData(
        trainingKeys.jobs.status(tenantId, data.job_id),
        {
          job_id: data.job_id,
          status: data.status,
          progress: 0,
        }
      );

      // Invalidate active model for this product
      queryClient.invalidateQueries({
        queryKey: trainingKeys.models.active(tenantId, inventoryProductId)
      });

      // Invalidate statistics
      queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });

      return callerOnSuccess?.(...args);
    },
  });
};
|
|||
|
|
|
|||
|
|
// Admin Mutations

/**
 * Deletes every model of a tenant via `trainingService.deleteAllTenantModels`.
 *
 * On success, cached model queries and the tenant's statistics are
 * invalidated. A caller-supplied `onSuccess` runs after that instead of
 * replacing it.
 */
export const useDeleteAllTenantModels = (
  options?: UseMutationOptions<{ message: string }, ApiError, { tenantId: string }>
) => {
  const queryClient = useQueryClient();
  // Pull `onSuccess` out so spreading the remaining options cannot overwrite
  // the invalidation below.
  const { onSuccess: callerOnSuccess, ...mutationOptions } = options || {};

  return useMutation<{ message: string }, ApiError, { tenantId: string }>({
    mutationFn: ({ tenantId }) => trainingService.deleteAllTenantModels(tenantId),
    ...mutationOptions,
    onSuccess: (...args) => {
      const [, { tenantId }] = args;

      // Model keys put tenantId after the 'list'/'detail'/'active'/... segment,
      // so a prefix cannot target one tenant: this invalidates cached model
      // queries for every tenant (harmless over-invalidation).
      queryClient.invalidateQueries({ queryKey: trainingKeys.models.all() });
      queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });

      return callerOnSuccess?.(...args);
    },
  });
};
|
|||
|
|
|
|||
|
|
// WebSocket Hook for Real-time Training Updates

/** Close reason used when the heartbeat detects a different user and forces a reconnect. */
const AUTH_SESSION_CHANGED_REASON = 'Auth session changed - reconnecting';

/**
 * Decodes the payload (second) segment of a JWT.
 *
 * JWT segments are base64url-encoded (`-`/`_`, padding stripped), which `atob`
 * rejects. The segment is therefore converted to standard base64 first.
 * Throws if the token is malformed.
 */
const decodeJwtPayload = (jwt: string): Record<string, unknown> => {
  const segment = jwt.split('.')[1] ?? '';
  const base64 = segment.replace(/-/g, '+').replace(/_/g, '/');
  const padded = base64 + '='.repeat((4 - (base64.length % 4)) % 4);
  return JSON.parse(atob(padded));
};

/**
 * True when two tokens belong to different users (a new auth session) rather
 * than being a refresh of the same session.
 *
 * Compares `user_id` and `sub`. If either token cannot be parsed, it logs a
 * warning and assumes the same session.
 */
const isNewAuthSession = (oldToken: string, newToken: string): boolean => {
  if (!oldToken || !newToken) return !!oldToken !== !!newToken;

  try {
    const oldPayload = decodeJwtPayload(oldToken);
    const newPayload = decodeJwtPayload(newToken);
    // Different user means new auth session; same user is just a token refresh.
    return oldPayload.user_id !== newPayload.user_id ||
           oldPayload.sub !== newPayload.sub;
  } catch (e) {
    console.warn('Failed to parse token for session comparison:', e);
    // On parse error, don't reconnect (assume same session)
    return false;
  }
};

/**
 * Maps a `product_completed` event onto the overall progress scale:
 * PROGRESS_DATA_ANALYSIS plus the completed fraction of the range up to
 * PROGRESS_TRAINING_RANGE_END. Missing counts default to 0 completed of 1 total.
 */
const computeProductCompletedProgress = (eventData: {
  products_completed?: number;
  total_products?: number;
}) => {
  const productsCompleted = eventData.products_completed || 0;
  const totalProducts = eventData.total_products || 1;
  const trainingRangeWidth = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS;
  const progress =
    PROGRESS_DATA_ANALYSIS + Math.floor((productsCompleted / totalProducts) * trainingRangeWidth);
  return { progress, productsCompleted, totalProducts };
};

/**
 * Subscribes to real-time updates for a single training job over a WebSocket.
 *
 * Every JSON message is written into the
 * `trainingKeys.jobs.status(tenantId, jobId)` cache entry, and the matching
 * callback in `options` is invoked.
 *
 * Connection behavior:
 * - authenticates with `apiClient.ensureValidToken()` (the `token` argument is
 *   not used);
 * - sends `get_status` on open and `ping` every `WEBSOCKET_HEARTBEAT_INTERVAL_MS`;
 * - reconnects after `WEBSOCKET_RECONNECT_INITIAL_DELAY_MS` when the heartbeat
 *   sees a token for a different user;
 * - after a non-1000 close, reconnects with exponential backoff (capped at
 *   `WEBSOCKET_RECONNECT_MAX_DELAY_MS`) up to `WEBSOCKET_MAX_RECONNECT_ATTEMPTS`
 *   times;
 * - stops reconnecting once the job reports `completed`, `failed` or `cancelled`.
 *
 * No connection is opened unless `tenantId`, `jobId` and `options` are all
 * provided. Callbacks are read through a ref, so new callback identities on
 * each render do not reopen the socket.
 *
 * @returns `isConnected` and the latest human-readable `connectionError`.
 */
export const useTrainingWebSocket = (
  tenantId: string,
  jobId: string,
  token?: string,
  options?: {
    onProgress?: (data: any) => void;
    onCompleted?: (data: any) => void;
    onError?: (error: any) => void;
    onStarted?: (data: any) => void;
    onCancelled?: (data: any) => void;
  }
) => {
  const queryClient = useQueryClient();
  const [isConnected, setIsConnected] = React.useState(false);
  const [connectionError, setConnectionError] = React.useState<string | null>(null);

  // Always call the latest callbacks without making them effect dependencies.
  // Depending on their identity closed and reopened the socket whenever a caller
  // passed inline functions.
  const optionsRef = React.useRef(options);
  React.useEffect(() => {
    optionsRef.current = options;
  }, [options]);
  const hasOptions = !!options;

  React.useEffect(() => {
    if (!tenantId || !jobId || !hasOptions) {
      return;
    }

    let ws: WebSocket | null = null;
    let reconnectTimer: ReturnType<typeof setTimeout> | null = null;
    let heartbeatTimer: ReturnType<typeof setInterval> | null = null;
    // Suppresses automatic backoff reconnects (set on cleanup and when the job ends).
    let isManuallyDisconnected = false;
    // Set only by this effect's cleanup; events arriving afterwards are ignored so
    // they cannot clobber state owned by a newer connection.
    let isDisposed = false;
    let reconnectAttempts = 0;
    const maxReconnectAttempts = WEBSOCKET_MAX_RECONNECT_ATTEMPTS;

    const stopHeartbeat = () => {
      if (heartbeatTimer !== null) {
        clearInterval(heartbeatTimer);
        heartbeatTimer = null;
      }
    };

    /** Applies one server message to the query cache and forwards it to the callbacks. */
    const handleMessage = (event: MessageEvent) => {
      try {
        // Handle non-JSON messages (like pong responses)
        if (typeof event.data === 'string' && event.data === 'pong') {
          console.log('🏓 Pong received from server');
          return;
        }

        const message = JSON.parse(event.data);
        console.log('🔔 Training WebSocket message received:', message);
        const callbacks = optionsRef.current;

        // Initial state messages only restore the latest known state; they are
        // not processed as regular events.
        if (message.type === 'initial_state') {
          console.log('📥 Received initial state:', message.data);
          const initialData = message.data;
          const initialEventData = initialData.data || {};
          let initialProgress: number | undefined = Number.isFinite(initialEventData.progress)
            ? initialEventData.progress
            : undefined;

          if (initialData.type === 'product_completed') {
            const computed = computeProductCompletedProgress(initialEventData);
            initialProgress = computed.progress;
            console.log('📦 Product training completed in initial state',
              `${computed.productsCompleted}/${computed.totalProducts}`,
              `progress: ${initialProgress}%`);
          }

          queryClient.setQueryData(
            trainingKeys.jobs.status(tenantId, jobId),
            (oldData: TrainingJobStatus | undefined) => ({
              ...oldData,
              job_id: jobId,
              status: initialData.type === 'completed' ? 'completed' :
                      initialData.type === 'failed' ? 'failed' :
                      initialData.type === 'started' ? 'running' :
                      initialData.type === 'progress' ? 'running' :
                      initialData.type === 'product_completed' ? 'running' :
                      initialData.type === 'step_completed' ? 'running' :
                      oldData?.status || 'running',
              // A state without numeric progress keeps the cached value instead of resetting to 0.
              progress: initialProgress ?? (oldData?.progress || 0),
              current_step: initialEventData.current_step || initialEventData.step_name || oldData?.current_step,
            })
          );
          return;
        }

        // Extract data from backend message structure
        const eventData = message.data || {};
        let progress: number | undefined = Number.isFinite(eventData.progress)
          ? eventData.progress
          : undefined;
        const currentStep = eventData.current_step || eventData.step_name || '';

        // Handle product_completed events - calculate progress dynamically
        if (message.type === 'product_completed') {
          const computed = computeProductCompletedProgress(eventData);
          progress = computed.progress;
          console.log('📦 Product training completed',
            `${computed.productsCompleted}/${computed.totalProducts}`,
            `progress: ${progress}%`);
        }

        // Update job status in cache
        queryClient.setQueryData(
          trainingKeys.jobs.status(tenantId, jobId),
          (oldData: TrainingJobStatus | undefined) => ({
            ...oldData,
            job_id: jobId,
            status: message.type === 'completed' ? 'completed' :
                    message.type === 'failed' ? 'failed' :
                    message.type === 'started' ? 'running' :
                    oldData?.status || 'running',
            // Messages without numeric progress (e.g. 'connected') used to reset
            // the cached progress to 0; keep the previous value instead.
            progress: progress ?? (oldData?.progress || 0),
            current_step: currentStep || oldData?.current_step,
          })
        );

        // Call appropriate callback based on message type
        switch (message.type) {
          case 'connected':
            console.log('🔗 WebSocket connected');
            break;

          case 'started':
            console.log('🚀 Training started');
            callbacks?.onStarted?.(message);
            break;

          case 'progress':
            console.log('📊 Training progress update', `${progress ?? 0}%`);
            callbacks?.onProgress?.(message);
            break;

          case 'product_completed':
            console.log('✅ Product training completed');
            // Treat as progress update, carrying the calculated progress
            callbacks?.onProgress?.({
              ...message,
              data: {
                ...eventData,
                progress,
              }
            });
            break;

          case 'step_completed':
            console.log('📋 Step completed');
            callbacks?.onProgress?.(message);
            break;

          case 'completed':
            console.log('✅ Training completed successfully');
            callbacks?.onCompleted?.(message);
            // Invalidate models and statistics
            queryClient.invalidateQueries({ queryKey: trainingKeys.models.all() });
            queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });
            isManuallyDisconnected = true;
            break;

          case 'failed':
            console.log('❌ Training failed');
            callbacks?.onError?.(message);
            isManuallyDisconnected = true;
            break;

          case 'cancelled':
            // NOTE(review): assumes the backend reports cancellation with type
            // 'cancelled'; previously `onCancelled` was accepted but never invoked.
            console.log('🛑 Training cancelled');
            callbacks?.onCancelled?.(message);
            isManuallyDisconnected = true;
            break;

          default:
            console.log(`🔍 Unknown message type: ${message.type}`);
            break;
        }
      } catch (error) {
        console.error('Error parsing WebSocket message:', error);
        setConnectionError('Error parsing message from server');
      }
    };

    const connect = async () => {
      if (isDisposed) {
        return;
      }

      try {
        setConnectionError(null);

        // Use centralized token management from apiClient
        let sessionToken: string;
        try {
          const validToken = await apiClient.ensureValidToken();
          if (!validToken) {
            throw new Error('No valid token available');
          }
          sessionToken = validToken;
        } catch (error) {
          if (isDisposed) return;
          console.error('❌ Failed to get valid token for WebSocket:', error);
          setConnectionError('Authentication failed. Please log in again.');
          return;
        }

        // Cleanup may have run while the token was being fetched; opening a
        // socket now would leak a connection that nothing ever closes.
        if (isDisposed) {
          return;
        }

        console.log(`🔄 Attempting WebSocket connection (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts + 1}):`, {
          tenantId,
          jobId,
          hasToken: !!sessionToken,
          tokenFromApiClient: true
        });

        // Handlers below reference this socket directly rather than the shared
        // `ws`, so a stale handler can never act on a newer connection.
        const socket = trainingService.createWebSocketConnection(tenantId, jobId, sessionToken);
        ws = socket;
        let connectedAt: number | null = null;

        socket.onopen = () => {
          console.log('✅ Training WebSocket connected successfully', {
            readyState: socket.readyState,
            url: socket.url,
            jobId
          });
          connectedAt = Date.now();
          setIsConnected(true);
          reconnectAttempts = 0; // Reset on successful connection

          // Request current status on connection
          try {
            socket.send('get_status');
            console.log('📤 Requested current training status');
          } catch (e) {
            console.warn('Failed to request status on connection:', e);
          }

          // Periodic ping plus auth-session check, bound to this socket only.
          stopHeartbeat();
          const heartbeat = setInterval(async () => {
            if (socket.readyState !== WebSocket.OPEN || isManuallyDisconnected) {
              clearInterval(heartbeat);
              return;
            }

            try {
              // Check token validity (this may refresh if needed)
              const currentToken = await apiClient.ensureValidToken();

              // Only reconnect if user changed (new auth session)
              if (currentToken && isNewAuthSession(sessionToken, currentToken)) {
                console.log('🔄 Auth session changed (different user) - reconnecting WebSocket');
                clearInterval(heartbeat);
                socket.close(1000, AUTH_SESSION_CHANGED_REASON);
                return;
              }

              // Token may have been refreshed but it's the same user - continue
              if (currentToken && currentToken !== sessionToken) {
                console.log('ℹ️ Token refreshed (same user) - updating reference');
                sessionToken = currentToken;
              }

              socket.send('ping');
              console.log('💓 Sent ping to server');
            } catch (e) {
              console.warn('Failed to send ping or validate token:', e);
              clearInterval(heartbeat);
            }
          }, WEBSOCKET_HEARTBEAT_INTERVAL_MS);
          heartbeatTimer = heartbeat;
        };

        socket.onmessage = (event) => {
          if (isDisposed) return;
          handleMessage(event);
        };

        socket.onerror = (error) => {
          if (isDisposed) return;
          console.error('Training WebSocket error:', error);
          setConnectionError('WebSocket connection error');
          setIsConnected(false);
        };

        socket.onclose = (event) => {
          console.log(`❌ Training WebSocket disconnected. Code: ${event.code}, Reason: "${event.reason}"`, {
            wasClean: event.wasClean,
            jobId,
            timeConnected: connectedAt !== null ? `${Date.now() - connectedAt}ms` : 'unknown',
            reconnectAttempts
          });
          stopHeartbeat();

          // Cleanup already reset the state of a disposed effect; never touch
          // shared state or reconnect from here.
          if (isDisposed) {
            return;
          }

          setIsConnected(false);

          // Detailed logging for different close codes
          switch (event.code) {
            case 1000:
              if (event.reason === AUTH_SESSION_CHANGED_REASON) {
                console.log('🔄 WebSocket closed for auth session change - will reconnect immediately');
              } else {
                console.log('🔒 WebSocket closed normally');
              }
              break;
            case 1006:
              console.log('⚠️ WebSocket closed abnormally (1006) - likely server-side issue or network problem');
              break;
            case 1001:
              console.log('🔄 WebSocket endpoint going away');
              break;
            case 1003:
              console.log('❌ WebSocket unsupported data received');
              break;
            default:
              console.log(`❓ WebSocket closed with code ${event.code}`);
          }

          // Handle auth session change reconnection (immediate reconnect)
          if (event.code === 1000 && event.reason === AUTH_SESSION_CHANGED_REASON) {
            console.log('🔄 Reconnecting immediately due to auth session change...');
            reconnectTimer = setTimeout(() => {
              void connect(); // Reconnect immediately with new session token
            }, WEBSOCKET_RECONNECT_INITIAL_DELAY_MS); // Short delay to allow cleanup
            return;
          }

          // Try to reconnect if not manually disconnected and haven't exceeded max attempts
          if (!isManuallyDisconnected && event.code !== 1000 && reconnectAttempts < maxReconnectAttempts) {
            const delay = Math.min(
              WEBSOCKET_RECONNECT_INITIAL_DELAY_MS * Math.pow(2, reconnectAttempts),
              WEBSOCKET_RECONNECT_MAX_DELAY_MS
            ); // Exponential backoff
            console.log(`🔄 Attempting to reconnect WebSocket in ${delay/1000}s... (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})`);

            reconnectTimer = setTimeout(() => {
              reconnectAttempts++;
              void connect();
            }, delay);
          } else if (reconnectAttempts >= maxReconnectAttempts) {
            console.log(`❌ Max reconnection attempts (${maxReconnectAttempts}) reached. Giving up.`);
            setConnectionError(`Connection failed after ${maxReconnectAttempts} attempts. The training job may not exist or the server may be unavailable.`);
          }
        };
      } catch (error) {
        console.error('Error creating WebSocket connection:', error);
        setConnectionError('Failed to create WebSocket connection');
      }
    };

    // Connect immediately to avoid missing early progress updates
    console.log('🚀 Starting immediate WebSocket connection...');
    void connect();

    // Cleanup function
    return () => {
      isDisposed = true;
      isManuallyDisconnected = true;

      if (reconnectTimer) {
        clearTimeout(reconnectTimer);
      }

      stopHeartbeat();

      if (ws) {
        ws.close(1000, 'Component unmounted');
      }

      setIsConnected(false);
    };
  }, [tenantId, jobId, queryClient, hasOptions]);

  return {
    isConnected,
    connectionError
  };
};
|
|||
|
|
|
|||
|
|
// Utility Hooks

/**
 * Reports whether a training job is still underway, i.e. its cached status is
 * `running` or `pending`.
 *
 * Status comes from `useTrainingJobStatus`, so the same HTTP/WebSocket gating
 * applies. The status query is only enabled when a `jobId` is given. Returns
 * false whenever no status data is available.
 */
export const useIsTrainingInProgress = (
  tenantId: string,
  jobId?: string,
  isWebSocketConnected?: boolean
) => {
  const statusQuery = useTrainingJobStatus(tenantId, jobId || '', {
    enabled: !!jobId,
    isWebSocketConnected,
  });

  const status = statusQuery.data?.status;
  return status === 'running' || status === 'pending';
};
|
|||
|
|
|
|||
|
|
/**
 * Derives display-friendly progress flags for a training job from
 * `useTrainingJobStatus`.
 *
 * Returns:
 * - `progress`: the cached progress, or 0 when absent/falsy;
 * - `currentStep`: the cached step, if any;
 * - `isComplete`, `isFailed`, `isRunning`: status equality checks.
 */
export const useTrainingProgress = (
  tenantId: string,
  jobId?: string,
  isWebSocketConnected?: boolean
) => {
  const statusQuery = useTrainingJobStatus(tenantId, jobId || '', {
    enabled: !!jobId,
    isWebSocketConnected,
  });

  const snapshot = statusQuery.data;
  const status = snapshot?.status;

  return {
    progress: snapshot?.progress || 0,
    currentStep: snapshot?.current_step,
    isComplete: status === 'completed',
    isFailed: status === 'failed',
    isRunning: status === 'running',
  };
};
|