/**
 * Training React Query hooks
 * Provides data fetching, caching, and state management for training operations
 */
import * as React from 'react';
import { useMutation, useQuery, useQueryClient, UseQueryOptions, UseMutationOptions } from '@tanstack/react-query';
import { trainingService } from '../services/training';
import { ApiError, apiClient } from '../client/apiClient';
import { useAuthStore } from '../../stores/auth.store';
import {
  HTTP_POLLING_INTERVAL_MS,
  HTTP_POLLING_DEBOUNCE_MS,
  WEBSOCKET_HEARTBEAT_INTERVAL_MS,
  WEBSOCKET_MAX_RECONNECT_ATTEMPTS,
  WEBSOCKET_RECONNECT_INITIAL_DELAY_MS,
  WEBSOCKET_RECONNECT_MAX_DELAY_MS,
  PROGRESS_DATA_ANALYSIS,
  PROGRESS_TRAINING_RANGE_START,
  PROGRESS_TRAINING_RANGE_END,
} from '../../constants/training';
import type {
  TrainingJobRequest,
  TrainingJobResponse,
  TrainingJobStatus,
  SingleProductTrainingRequest,
  ModelMetricsResponse,
  TrainedModelResponse,
} from '../types/training';

// Query Keys Factory
export const trainingKeys = {
  all: ['training'] as const,
  jobs: {
    all: () => [...trainingKeys.all, 'jobs'] as const,
    status: (tenantId: string, jobId: string) =>
      [...trainingKeys.jobs.all(), 'status', tenantId, jobId] as const,
  },
  models: {
    all: () => [...trainingKeys.all, 'models'] as const,
    lists: () => [...trainingKeys.models.all(), 'list'] as const,
    list: (tenantId: string, params?: unknown) =>
      [...trainingKeys.models.lists(), tenantId, params] as const,
    details: () => [...trainingKeys.models.all(), 'detail'] as const,
    detail: (tenantId: string, modelId: string) =>
      [...trainingKeys.models.details(), tenantId, modelId] as const,
    active: (tenantId: string, inventoryProductId: string) =>
      [...trainingKeys.models.all(), 'active', tenantId, inventoryProductId] as const,
    metrics: (tenantId: string, modelId: string) =>
      [...trainingKeys.models.all(), 'metrics', tenantId, modelId] as const,
    performance: (tenantId: string, modelId: string) =>
      [...trainingKeys.models.all(), 'performance', tenantId, modelId] as const,
  },
  statistics: (tenantId: string) => [...trainingKeys.all, 'statistics', tenantId] as const,
} as const;

/** Caller-overridable query options for a hook whose queryFn resolves to `TData`. */
type TrainingQueryOptions<TData> = Omit<UseQueryOptions<TData, ApiError>, 'queryKey' | 'queryFn'>;

/**
 * Resolved value of a `trainingService` method. Deriving option types from the service keeps each
 * hook's `options` data type in lock-step with what its `queryFn` actually returns.
 */
type ServiceResult<K extends keyof typeof trainingService> =
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- standard "any function" constraint for `infer`
  (typeof trainingService)[K] extends (...args: any[]) => infer R ? Awaited<R> : never;

/** True when an error object carries HTTP status 401 (unauthenticated). */
const isAuthError = (error: unknown): boolean =>
  typeof error === 'object' && error !== null && (error as { status?: unknown }).status === 401;

// Training Job Queries

/**
 * Polls a training job's status over HTTP as a fallback for the WebSocket feed.
 *
 * The query only runs while the WebSocket is reported as disconnected, and only after
 * `HTTP_POLLING_DEBOUNCE_MS` has elapsed since it disconnected, so both transports are never active
 * at the same time. Polling stops on a 401 error or once the job reports `completed`/`failed`.
 *
 * @param tenantId - Tenant owning the job; the query stays disabled while empty.
 * @param jobId - Training job id; the query stays disabled while empty.
 * @param options - React Query options plus `isWebSocketConnected`. A caller-supplied `enabled`
 *   can further restrict the query but cannot re-enable HTTP polling while the WebSocket is
 *   connected or during the debounce window.
 */
export const useTrainingJobStatus = (
  tenantId: string,
  jobId: string,
  options?: TrainingQueryOptions<TrainingJobStatus> & {
    isWebSocketConnected?: boolean;
  },
) => {
  const { isWebSocketConnected, enabled: callerEnabled = true, ...queryOptions } = options ?? {};
  const [enablePolling, setEnablePolling] = React.useState(false);

  // Debounce HTTP polling activation: wait after the WebSocket disconnects.
  // This prevents race conditions where both WebSocket and HTTP are briefly active.
  React.useEffect(() => {
    if (isWebSocketConnected) {
      setEnablePolling(false);
      console.log('❌ HTTP polling disabled (WebSocket connected)');
      return undefined;
    }

    const debounceTimer = setTimeout(() => {
      setEnablePolling(true);
      console.log(`🔄 HTTP polling enabled after ${HTTP_POLLING_DEBOUNCE_MS}ms debounce (WebSocket disconnected)`);
    }, HTTP_POLLING_DEBOUNCE_MS);
    return () => clearTimeout(debounceTimer);
  }, [isWebSocketConnected]);

  // Completely disable the query when the WebSocket is connected or during the debounce period.
  const isEnabled = !!tenantId && !!jobId && !isWebSocketConnected && enablePolling;

  console.log('🔄 Training status query:', {
    tenantId: !!tenantId,
    jobId: !!jobId,
    isWebSocketConnected,
    enablePolling,
    queryEnabled: isEnabled,
  });

  return useQuery({
    queryKey: trainingKeys.jobs.status(tenantId, jobId),
    queryFn: () => {
      console.log('📡 Executing HTTP training status query (WebSocket disconnected)');
      return trainingService.getTrainingJobStatus(tenantId, jobId);
    },
    // WebSocket gating always wins; the caller's `enabled` may only narrow it further.
    // (Previously `...queryOptions` came last and a caller's `enabled` overrode the gating.)
    enabled: isEnabled ? callerEnabled : false,
    refetchInterval: isEnabled
      ? (query) => {
          const data = query.state.data;

          // Stop polling on auth errors: retrying with the same credentials cannot succeed.
          if (query.state.error && isAuthError(query.state.error)) {
            console.log('🚫 Stopping status polling due to auth error');
            return false;
          }

          if (data?.status === 'completed' || data?.status === 'failed') {
            console.log('🏁 Training completed - stopping HTTP polling');
            return false; // Terminal state: nothing left to poll for.
          }

          console.log(`📊 HTTP fallback polling active (WebSocket disconnected) - ${HTTP_POLLING_INTERVAL_MS}ms interval`);
          return HTTP_POLLING_INTERVAL_MS;
        }
      : false,
    staleTime: 1000, // Consider data stale after 1 second
    retry: (failureCount, error) => {
      // Don't retry on auth errors
      if (isAuthError(error)) {
        console.log('🚫 Not retrying due to auth error');
        return false;
      }
      return failureCount < 3;
    },
    ...queryOptions,
  });
};

// Model Queries

/** Fetches the currently active model for one inventory product (cached for 5 minutes). */
export const useActiveModel = (
  tenantId: string,
  inventoryProductId: string,
  options?: TrainingQueryOptions<ServiceResult<'getActiveModel'>>,
) => {
  return useQuery({
    queryKey: trainingKeys.models.active(tenantId, inventoryProductId),
    queryFn: () => trainingService.getActiveModel(tenantId, inventoryProductId),
    enabled: !!tenantId && !!inventoryProductId,
    staleTime: 5 * 60 * 1000, // 5 minutes
    ...options,
  });
};

/**
 * Lists a tenant's trained models (cached for 2 minutes).
 *
 * @param queryParams - Forwarded verbatim to `trainingService.getModels` and included in the query key.
 */
export const useModels = (
  tenantId: string,
  queryParams?: Parameters<typeof trainingService.getModels>[1],
  options?: TrainingQueryOptions<ServiceResult<'getModels'>>,
) => {
  return useQuery({
    queryKey: trainingKeys.models.list(tenantId, queryParams),
    queryFn: () => trainingService.getModels(tenantId, queryParams),
    enabled: !!tenantId,
    staleTime: 2 * 60 * 1000, // 2 minutes
    ...options,
  });
};

/** Fetches metrics for a single model (cached for 10 minutes). */
export const useModelMetrics = (
  tenantId: string,
  modelId: string,
  options?: TrainingQueryOptions<ServiceResult<'getModelMetrics'>>,
) => {
  return useQuery({
    queryKey: trainingKeys.models.metrics(tenantId, modelId),
    queryFn: () => trainingService.getModelMetrics(tenantId, modelId),
    enabled: !!tenantId && !!modelId,
    staleTime: 10 * 60 * 1000, // 10 minutes
    ...options,
  });
};

/** Fetches performance data for a single model (cached for 10 minutes). */
export const useModelPerformance = (
  tenantId: string,
  modelId: string,
  options?: TrainingQueryOptions<ServiceResult<'getModelPerformance'>>,
) => {
  return useQuery({
    queryKey: trainingKeys.models.performance(tenantId, modelId),
    queryFn: () => trainingService.getModelPerformance(tenantId, modelId),
    enabled: !!tenantId && !!modelId,
    staleTime: 10 * 60 * 1000, // 10 minutes
    ...options,
  });
};

// Statistics Queries

/** Fetches tenant-wide training statistics (cached for 5 minutes). */
export const useTenantTrainingStatistics = (
  tenantId: string,
  options?: TrainingQueryOptions<ServiceResult<'getTenantStatistics'>>,
) => {
  return useQuery({
    queryKey: trainingKeys.statistics(tenantId),
    queryFn: () => trainingService.getTenantStatistics(tenantId),
    enabled: !!tenantId,
    staleTime: 5 * 60 * 1000, // 5 minutes
    ...options,
  });
};

// Training Job Mutations

/** Variables accepted by `useCreateTrainingJob`. */
type CreateTrainingJobVariables = { tenantId: string; request: TrainingJobRequest };

/** Variables accepted by `useTrainSingleProduct`. */
type TrainSingleProductVariables = {
  tenantId: string;
  inventoryProductId: string;
  request: SingleProductTrainingRequest;
};

/** Seeds the job-status cache entry for a freshly created job so status consumers render immediately. */
const seedJobStatus = (
  queryClient: ReturnType<typeof useQueryClient>,
  tenantId: string,
  job: TrainingJobResponse,
): void => {
  queryClient.setQueryData(trainingKeys.jobs.status(tenantId, job.job_id), {
    job_id: job.job_id,
    status: job.status,
    progress: 0,
  });
};

/**
 * Starts a tenant-wide training job.
 *
 * On success, seeds the job-status cache and invalidates tenant statistics, then invokes the caller's
 * `onSuccess` (if any). Previously a caller-supplied `onSuccess` replaced the cache updates entirely.
 */
export const useCreateTrainingJob = (
  options?: UseMutationOptions<TrainingJobResponse, ApiError, CreateTrainingJobVariables>,
) => {
  const queryClient = useQueryClient();
  const { onSuccess: callerOnSuccess, ...mutationOptions } = options ?? {};

  return useMutation<TrainingJobResponse, ApiError, CreateTrainingJobVariables>({
    mutationFn: ({ tenantId, request }) => trainingService.createTrainingJob(tenantId, request),
    ...mutationOptions,
    onSuccess: (data, variables, ...rest) => {
      const { tenantId } = variables;
      seedJobStatus(queryClient, tenantId, data);
      // Invalidate statistics to reflect the new training job
      void queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });
      return callerOnSuccess?.(data, variables, ...rest);
    },
  });
};

/**
 * Starts a training job for a single inventory product.
 *
 * On success, seeds the job-status cache and invalidates that product's active model and the
 * tenant statistics, then invokes the caller's `onSuccess` (if any).
 */
export const useTrainSingleProduct = (
  options?: UseMutationOptions<TrainingJobResponse, ApiError, TrainSingleProductVariables>,
) => {
  const queryClient = useQueryClient();
  const { onSuccess: callerOnSuccess, ...mutationOptions } = options ?? {};

  return useMutation<TrainingJobResponse, ApiError, TrainSingleProductVariables>({
    mutationFn: ({ tenantId, inventoryProductId, request }) =>
      trainingService.trainSingleProduct(tenantId, inventoryProductId, request),
    ...mutationOptions,
    onSuccess: (data, variables, ...rest) => {
      const { tenantId, inventoryProductId } = variables;
      seedJobStatus(queryClient, tenantId, data);
      void queryClient.invalidateQueries({
        queryKey: trainingKeys.models.active(tenantId, inventoryProductId),
      });
      void queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });
      return callerOnSuccess?.(data, variables, ...rest);
    },
  });
};

// Admin Mutations

/**
 * Deletes every model of a tenant.
 *
 * On success, invalidates model queries and the tenant statistics, then invokes the caller's
 * `onSuccess` (if any).
 */
export const useDeleteAllTenantModels = (
  options?: UseMutationOptions<{ message: string }, ApiError, { tenantId: string }>,
) => {
  const queryClient = useQueryClient();
  const { onSuccess: callerOnSuccess, ...mutationOptions } = options ?? {};

  return useMutation<{ message: string }, ApiError, { tenantId: string }>({
    mutationFn: ({ tenantId }) => trainingService.deleteAllTenantModels(tenantId),
    ...mutationOptions,
    onSuccess: (data, variables, ...rest) => {
      // Model keys are not tenant-prefixed at this level, so this invalidates model queries for
      // every tenant — a safe superset of "this tenant's models".
      void queryClient.invalidateQueries({ queryKey: trainingKeys.models.all() });
      void queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(variables.tenantId) });
      return callerOnSuccess?.(data, variables, ...rest);
    },
  });
};

// WebSocket Hook for Real-time Training Updates

/** Callbacks for training WebSocket events; each receives the parsed server message. */
export interface TrainingWebSocketCallbacks {
  /* eslint-disable @typescript-eslint/no-explicit-any -- server message payloads are not schema-validated */
  onProgress?: (data: any) => void;
  onCompleted?: (data: any) => void;
  onError?: (error: any) => void;
  onStarted?: (data: any) => void;
  onCancelled?: (data: any) => void;
  /* eslint-enable @typescript-eslint/no-explicit-any */
}

/** Loosely-typed envelope of a training event pushed over the WebSocket. */
interface TrainingSocketMessage {
  type: string;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- payload fields are read defensively below
  data?: any;
}

/** Close reason used when the heartbeat detects a different signed-in user; triggers a fast reconnect. */
const AUTH_SESSION_CHANGED_REASON = 'Auth session changed - reconnecting';

/**
 * Decodes a JWT's payload segment without verifying its signature.
 * JWTs use base64url, which `atob` rejects (`-`/`_`), so the segment is translated and re-padded first.
 */
const decodeJwtPayload = (jwt: string): Record<string, unknown> => {
  const segment = jwt.split('.')[1];
  if (!segment) {
    throw new Error('Malformed JWT: missing payload segment');
  }
  const base64 = segment.replace(/-/g, '+').replace(/_/g, '/');
  const padded = base64.padEnd(Math.ceil(base64.length / 4) * 4, '=');
  return JSON.parse(atob(padded));
};

/**
 * True when two tokens belong to different users (a new auth session) rather than being a refresh
 * of the same session. Unparseable tokens are treated as the same session (no reconnect).
 */
const isNewAuthSession = (oldToken: string, newToken: string): boolean => {
  if (!oldToken || !newToken) return !!oldToken !== !!newToken;
  try {
    const oldPayload = decodeJwtPayload(oldToken);
    const newPayload = decodeJwtPayload(newToken);
    // Same user id => just a token refresh, no need to reconnect.
    return oldPayload.user_id !== newPayload.user_id || oldPayload.sub !== newPayload.sub;
  } catch (e) {
    console.warn('Failed to parse token for session comparison:', e);
    return false;
  }
};

/** Numeric `progress` from an event payload, or `undefined` when absent/non-numeric (keep cached value). */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const readProgress = (eventData: Record<string, any>): number | undefined =>
  typeof eventData.progress === 'number' ? eventData.progress : undefined;

/**
 * Maps "completed of total products trained" onto the training band of the overall progress bar:
 * PROGRESS_DATA_ANALYSIS + floor(completed / total * (PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS)).
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const productTrainingProgress = (eventData: Record<string, any>) => {
  const completed = eventData.products_completed || 0;
  const total = eventData.total_products || 1; // Avoid division by zero when the total is missing.
  const trainingRangeWidth = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS;
  return {
    completed,
    total,
    progress: PROGRESS_DATA_ANALYSIS + Math.floor((completed / total) * trainingRangeWidth),
  };
};

/** Job status implied by the event type stored in a replayed `initial_state` snapshot. */
const statusFromInitialState = (type: unknown, previous: TrainingJobStatus['status'] | undefined) => {
  switch (type) {
    case 'completed':
      return 'completed';
    case 'failed':
      return 'failed';
    case 'started':
    case 'progress':
    case 'product_completed':
    case 'step_completed':
      return 'running';
    default:
      return previous || 'running';
  }
};

/** Job status implied by a live event type; unrecognised types keep the cached status. */
const statusFromLiveEvent = (type: string, previous: TrainingJobStatus['status'] | undefined) => {
  switch (type) {
    case 'completed':
      return 'completed';
    case 'failed':
      return 'failed';
    case 'started':
      return 'running';
    default:
      return previous || 'running';
  }
};

/**
 * Subscribes to real-time training updates for one job and mirrors them into the job-status cache.
 *
 * Reconnects with exponential backoff (up to WEBSOCKET_MAX_RECONNECT_ATTEMPTS) on abnormal closes,
 * sends a periodic heartbeat, and reconnects when the signed-in user changes. No connection is
 * opened until `options` is supplied. Callbacks are read from a ref, so passing new function
 * identities on re-render does NOT tear down the socket.
 *
 * @param token - Unused: the socket always authenticates with `apiClient.ensureValidToken()`;
 *   kept for API compatibility.
 * @returns `isConnected` and the latest human-readable `connectionError` (or `null`).
 */
export const useTrainingWebSocket = (
  tenantId: string,
  jobId: string,
  token?: string,
  options?: TrainingWebSocketCallbacks,
) => {
  const queryClient = useQueryClient();
  const [isConnected, setIsConnected] = React.useState(false);
  const [connectionError, setConnectionError] = React.useState<string | null>(null);

  // Always read the latest callbacks without making them effect dependencies.
  const callbacksRef = React.useRef(options);
  React.useEffect(() => {
    callbacksRef.current = options;
  });
  const hasCallbacks = options !== undefined;

  React.useEffect(() => {
    if (!tenantId || !jobId || !hasCallbacks) {
      return undefined;
    }

    let ws: WebSocket | null = null;
    let reconnectTimer: ReturnType<typeof setTimeout> | null = null;
    let heartbeatTimer: ReturnType<typeof setInterval> | null = null;
    // Set on unmount/dependency change AND on terminal job events: suppresses backoff reconnects.
    let isManuallyDisconnected = false;
    // Set only by effect cleanup: nothing may open or reschedule after this.
    let isDisposed = false;
    let reconnectAttempts = 0;
    const maxReconnectAttempts = WEBSOCKET_MAX_RECONNECT_ATTEMPTS;

    const stopHeartbeat = () => {
      if (heartbeatTimer !== null) {
        clearInterval(heartbeatTimer);
        heartbeatTimer = null;
      }
    };

    /** Handles one parsed (non-pong) server message. */
    const handleMessage = (message: TrainingSocketMessage) => {
      // Initial state messages only restore the latest known state; they are not regular events.
      if (message.type === 'initial_state') {
        console.log('📥 Received initial state:', message.data);
        const initialData = message.data ?? {};
        const initialEventData = initialData.data || {};
        let initialProgress = readProgress(initialEventData);

        if (initialData.type === 'product_completed') {
          const { completed, total, progress } = productTrainingProgress(initialEventData);
          initialProgress = progress;
          console.log('📦 Product training completed in initial state', `${completed}/${total}`, `progress: ${initialProgress}%`);
        }

        queryClient.setQueryData(
          trainingKeys.jobs.status(tenantId, jobId),
          (oldData: TrainingJobStatus | undefined) => ({
            ...oldData,
            job_id: jobId,
            status: statusFromInitialState(initialData.type, oldData?.status),
            progress: initialProgress ?? oldData?.progress ?? 0,
            current_step: initialEventData.current_step || initialEventData.step_name || oldData?.current_step,
          }),
        );
        return;
      }

      const eventData = message.data || {};
      // `undefined` when the event carries no progress, so the cached value is kept instead of reset to 0.
      let progress = readProgress(eventData);
      const currentStep = eventData.current_step || eventData.step_name || '';

      if (message.type === 'product_completed') {
        const result = productTrainingProgress(eventData);
        progress = result.progress;
        console.log('📦 Product training completed', `${result.completed}/${result.total}`, `progress: ${progress}%`);
      }

      queryClient.setQueryData(
        trainingKeys.jobs.status(tenantId, jobId),
        (oldData: TrainingJobStatus | undefined) => ({
          ...oldData,
          job_id: jobId,
          status: statusFromLiveEvent(message.type, oldData?.status),
          progress: progress ?? oldData?.progress ?? 0,
          current_step: currentStep || oldData?.current_step,
        }),
      );

      const callbacks = callbacksRef.current;
      switch (message.type) {
        case 'connected':
          console.log('🔗 WebSocket connected');
          break;
        case 'started':
          console.log('🚀 Training started');
          callbacks?.onStarted?.(message);
          break;
        case 'progress':
          console.log('📊 Training progress update', `${progress ?? 0}%`);
          callbacks?.onProgress?.(message);
          break;
        case 'product_completed':
          console.log('✅ Product training completed');
          // Treat as a progress update carrying the calculated progress.
          callbacks?.onProgress?.({ ...message, data: { ...eventData, progress } });
          break;
        case 'step_completed':
          console.log('📋 Step completed');
          callbacks?.onProgress?.(message);
          break;
        case 'completed':
          console.log('✅ Training completed successfully');
          callbacks?.onCompleted?.(message);
          void queryClient.invalidateQueries({ queryKey: trainingKeys.models.all() });
          void queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });
          isManuallyDisconnected = true;
          break;
        case 'failed':
          console.log('❌ Training failed');
          callbacks?.onError?.(message);
          isManuallyDisconnected = true;
          break;
        case 'cancelled':
          // NOTE(review): assumes the backend emits `type: 'cancelled'` for cancelled jobs — confirm.
          console.log('🛑 Training cancelled');
          callbacks?.onCancelled?.(message);
          isManuallyDisconnected = true;
          break;
        default:
          console.log(`🔍 Unknown message type: ${message.type}`);
          break;
      }
    };

    const connect = async (): Promise<void> => {
      try {
        setConnectionError(null);

        // Use centralized token management from apiClient.
        let sessionToken: string;
        try {
          const freshToken = await apiClient.ensureValidToken();
          if (!freshToken) {
            throw new Error('No valid token available');
          }
          sessionToken = freshToken;
        } catch (error) {
          console.error('❌ Failed to get valid token for WebSocket:', error);
          setConnectionError('Authentication failed. Please log in again.');
          return;
        }

        // The effect may have been cleaned up while awaiting the token; don't open an orphan socket.
        if (isDisposed) {
          return;
        }

        console.log(`🔄 Attempting WebSocket connection (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts + 1}):`, {
          tenantId,
          jobId,
          hasToken: !!sessionToken,
          tokenFromApiClient: true,
        });

        const socket = trainingService.createWebSocketConnection(tenantId, jobId, sessionToken);
        ws = socket;
        let connectedAt: number | null = null;

        socket.onopen = () => {
          console.log('✅ Training WebSocket connected successfully', {
            readyState: socket.readyState,
            url: socket.url,
            jobId,
          });
          connectedAt = Date.now();
          setIsConnected(true);
          reconnectAttempts = 0; // Reset on successful connection

          // Request current status on connection.
          try {
            socket.send('get_status');
            console.log('📤 Requested current training status');
          } catch (e) {
            console.warn('Failed to request status on connection:', e);
          }

          // Periodic ping plus auth-session check. Each interval clears only itself so a stale tick
          // can never cancel the heartbeat of a newer socket.
          stopHeartbeat();
          const intervalId = setInterval(async () => {
            if (socket.readyState !== WebSocket.OPEN || isManuallyDisconnected) {
              clearInterval(intervalId);
              return;
            }
            try {
              // Check token validity (this may refresh it).
              const currentToken = await apiClient.ensureValidToken();
              if (socket.readyState !== WebSocket.OPEN) {
                clearInterval(intervalId);
                return;
              }

              // Only reconnect if the user changed (new auth session).
              if (currentToken && isNewAuthSession(sessionToken, currentToken)) {
                console.log('🔄 Auth session changed (different user) - reconnecting WebSocket');
                clearInterval(intervalId);
                socket.close(1000, AUTH_SESSION_CHANGED_REASON);
                return;
              }

              // Token refreshed for the same user - keep comparing against the newest one.
              if (currentToken && currentToken !== sessionToken) {
                console.log('ℹ️ Token refreshed (same user) - updating reference');
                sessionToken = currentToken;
              }

              socket.send('ping');
              console.log('💓 Sent ping to server');
            } catch (e) {
              console.warn('Failed to send ping or validate token:', e);
              clearInterval(intervalId);
            }
          }, WEBSOCKET_HEARTBEAT_INTERVAL_MS);
          heartbeatTimer = intervalId;
        };

        socket.onmessage = (event) => {
          try {
            // Handle non-JSON messages (like pong responses).
            if (typeof event.data === 'string' && event.data === 'pong') {
              console.log('🏓 Pong received from server');
              return;
            }
            const message: TrainingSocketMessage = JSON.parse(event.data);
            console.log('🔔 Training WebSocket message received:', message);
            handleMessage(message);
          } catch (error) {
            console.error('Error parsing WebSocket message:', error);
            setConnectionError('Error parsing message from server');
          }
        };

        socket.onerror = (error) => {
          console.error('Training WebSocket error:', error);
          setConnectionError('WebSocket connection error');
          setIsConnected(false);
        };

        socket.onclose = (event) => {
          // Only one socket is alive at a time, so the active heartbeat belongs to this socket.
          stopHeartbeat();

          console.log(`❌ Training WebSocket disconnected. Code: ${event.code}, Reason: "${event.reason}"`, {
            wasClean: event.wasClean,
            jobId,
            timeConnected: connectedAt !== null ? `${Date.now() - connectedAt}ms` : 'unknown',
            reconnectAttempts,
          });
          setIsConnected(false);

          // Detailed logging for different close codes.
          switch (event.code) {
            case 1000:
              if (event.reason === AUTH_SESSION_CHANGED_REASON) {
                console.log('🔄 WebSocket closed for auth session change - will reconnect immediately');
              } else {
                console.log('🔒 WebSocket closed normally');
              }
              break;
            case 1006:
              console.log('⚠️ WebSocket closed abnormally (1006) - likely server-side issue or network problem');
              break;
            case 1001:
              console.log('🔄 WebSocket endpoint going away');
              break;
            case 1003:
              console.log('❌ WebSocket unsupported data received');
              break;
            default:
              console.log(`❓ WebSocket closed with code ${event.code}`);
          }

          // Auth session change: reconnect quickly with the new user's token.
          if (event.code === 1000 && event.reason === AUTH_SESSION_CHANGED_REASON) {
            if (isDisposed) {
              return;
            }
            console.log('🔄 Reconnecting immediately due to auth session change...');
            reconnectTimer = setTimeout(() => {
              void connect();
            }, WEBSOCKET_RECONNECT_INITIAL_DELAY_MS); // Short delay to allow cleanup
            return;
          }

          // Reconnect with exponential backoff unless intentionally closed or out of attempts.
          if (!isManuallyDisconnected && event.code !== 1000 && reconnectAttempts < maxReconnectAttempts) {
            const delay = Math.min(
              WEBSOCKET_RECONNECT_INITIAL_DELAY_MS * Math.pow(2, reconnectAttempts),
              WEBSOCKET_RECONNECT_MAX_DELAY_MS,
            );
            console.log(`🔄 Attempting to reconnect WebSocket in ${delay/1000}s... (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})`);
            reconnectTimer = setTimeout(() => {
              reconnectAttempts++;
              void connect();
            }, delay);
          } else if (reconnectAttempts >= maxReconnectAttempts) {
            console.log(`❌ Max reconnection attempts (${maxReconnectAttempts}) reached. Giving up.`);
            setConnectionError(`Connection failed after ${maxReconnectAttempts} attempts. The training job may not exist or the server may be unavailable.`);
          }
        };
      } catch (error) {
        console.error('Error creating WebSocket connection:', error);
        setConnectionError('Failed to create WebSocket connection');
      }
    };

    // Connect immediately to avoid missing early progress updates.
    console.log('🚀 Starting immediate WebSocket connection...');
    void connect();

    return () => {
      isDisposed = true;
      isManuallyDisconnected = true;
      if (reconnectTimer) {
        clearTimeout(reconnectTimer);
      }
      stopHeartbeat();
      if (ws) {
        ws.close(1000, 'Component unmounted');
      }
      setIsConnected(false);
    };
  }, [tenantId, jobId, queryClient, hasCallbacks]);

  return { isConnected, connectionError };
};

// Utility Hooks

/** True while the job's status is `running` or `pending` (HTTP fallback only; see `useTrainingJobStatus`). */
export const useIsTrainingInProgress = (
  tenantId: string,
  jobId?: string,
  isWebSocketConnected?: boolean,
) => {
  const { data: jobStatus } = useTrainingJobStatus(tenantId, jobId ?? '', {
    enabled: !!jobId,
    isWebSocketConnected,
  });

  return jobStatus?.status === 'running' || jobStatus?.status === 'pending';
};

/** Derived progress view of a job's cached status (0 when unknown). */
export const useTrainingProgress = (
  tenantId: string,
  jobId?: string,
  isWebSocketConnected?: boolean,
) => {
  const { data: jobStatus } = useTrainingJobStatus(tenantId, jobId ?? '', {
    enabled: !!jobId,
    isWebSocketConnected,
  });

  return {
    progress: jobStatus?.progress ?? 0,
    currentStep: jobStatus?.current_step,
    isComplete: jobStatus?.status === 'completed',
    isFailed: jobStatus?.status === 'failed',
    isRunning: jobStatus?.status === 'running',
  };
};