Files
bakery-ia/frontend/src/api/hooks/training.ts

708 lines
26 KiB
TypeScript
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* Training React Query hooks
* Provides data fetching, caching, and state management for training operations
*/
import * as React from 'react';
import { useMutation, useQuery, useQueryClient, UseQueryOptions, UseMutationOptions } from '@tanstack/react-query';
import { trainingService } from '../services/training';
import { ApiError, apiClient } from '../client/apiClient';
import { useAuthStore } from '../../stores/auth.store';
import {
HTTP_POLLING_INTERVAL_MS,
HTTP_POLLING_DEBOUNCE_MS,
WEBSOCKET_HEARTBEAT_INTERVAL_MS,
WEBSOCKET_MAX_RECONNECT_ATTEMPTS,
WEBSOCKET_RECONNECT_INITIAL_DELAY_MS,
WEBSOCKET_RECONNECT_MAX_DELAY_MS,
PROGRESS_DATA_ANALYSIS,
PROGRESS_TRAINING_RANGE_START,
PROGRESS_TRAINING_RANGE_END
} from '../../constants/training';
import type {
TrainingJobRequest,
TrainingJobResponse,
TrainingJobStatus,
SingleProductTrainingRequest,
ModelMetricsResponse,
TrainedModelResponse,
} from '../types/training';
// Query Keys Factory
/**
 * React Query key factory for everything under the `training` namespace.
 *
 * Keys are built hierarchically: every key starts with `trainingKeys.all`,
 * so invalidating a shorter key (for example `trainingKeys.models.all()`)
 * also matches every longer key that extends it.
 */
export const trainingKeys = {
  /** Root segment shared by every training key. */
  all: ['training'] as const,

  /** Keys for training jobs. */
  jobs: {
    all: () => {
      return [...trainingKeys.all, 'jobs'] as const;
    },
    status: (tenantId: string, jobId: string) => {
      const jobsRoot = trainingKeys.jobs.all();
      return [...jobsRoot, 'status', tenantId, jobId] as const;
    },
  },

  /** Keys for trained models. */
  models: {
    all: () => {
      return [...trainingKeys.all, 'models'] as const;
    },
    lists: () => {
      const modelsRoot = trainingKeys.models.all();
      return [...modelsRoot, 'list'] as const;
    },
    list: (tenantId: string, params?: any) => {
      const listsRoot = trainingKeys.models.lists();
      return [...listsRoot, tenantId, params] as const;
    },
    details: () => {
      const modelsRoot = trainingKeys.models.all();
      return [...modelsRoot, 'detail'] as const;
    },
    detail: (tenantId: string, modelId: string) => {
      const detailsRoot = trainingKeys.models.details();
      return [...detailsRoot, tenantId, modelId] as const;
    },
    active: (tenantId: string, inventoryProductId: string) => {
      const modelsRoot = trainingKeys.models.all();
      return [...modelsRoot, 'active', tenantId, inventoryProductId] as const;
    },
    metrics: (tenantId: string, modelId: string) => {
      const modelsRoot = trainingKeys.models.all();
      return [...modelsRoot, 'metrics', tenantId, modelId] as const;
    },
    performance: (tenantId: string, modelId: string) => {
      const modelsRoot = trainingKeys.models.all();
      return [...modelsRoot, 'performance', tenantId, modelId] as const;
    },
  },

  /** Key for the per-tenant training statistics query. */
  statistics: (tenantId: string) => {
    return [...trainingKeys.all, 'statistics', tenantId] as const;
  },
} as const;
// Training Job Queries
/**
 * HTTP fallback query for a training job's status.
 *
 * The query only runs while the WebSocket is reported as disconnected, and
 * only after a debounce of `HTTP_POLLING_DEBOUNCE_MS`, so HTTP polling and
 * the WebSocket are not both active during brief connection flaps.
 *
 * @param tenantId Tenant that owns the job; the query is disabled when empty.
 * @param jobId Training job id; the query is disabled when empty.
 * @param options Standard query options plus `isWebSocketConnected`.
 *   A caller-supplied boolean `enabled` is combined (AND) with the internal
 *   WebSocket/debounce gate instead of replacing it. A function-valued
 *   `enabled` is not evaluated here and is treated as "enabled".
 *   Any other option (including `refetchInterval`, `staleTime`, `retry`)
 *   still overrides the defaults set by this hook.
 * @returns The `useQuery` result for the job status.
 */
export const useTrainingJobStatus = (
  tenantId: string,
  jobId: string,
  options?: Omit<UseQueryOptions<TrainingJobStatus, ApiError>, 'queryKey' | 'queryFn'> & {
    isWebSocketConnected?: boolean;
  }
) => {
  // `enabled` is pulled out so that spreading the remaining options below can
  // no longer overwrite the WebSocket-aware gate computed by this hook.
  const { isWebSocketConnected, enabled: callerEnabled, ...queryOptions } = options || {};
  const [enablePolling, setEnablePolling] = React.useState(false);

  // Debounce HTTP polling activation: wait after the WebSocket disconnects.
  // This prevents both WebSocket and HTTP from being briefly active together.
  React.useEffect(() => {
    if (isWebSocketConnected) {
      setEnablePolling(false);
      console.log('❌ HTTP polling disabled (WebSocket connected)');
      return undefined;
    }

    const debounceTimer = setTimeout(() => {
      setEnablePolling(true);
      console.log(`🔄 HTTP polling enabled after ${HTTP_POLLING_DEBOUNCE_MS}ms debounce (WebSocket disconnected)`);
    }, HTTP_POLLING_DEBOUNCE_MS);
    return () => clearTimeout(debounceTimer);
  }, [isWebSocketConnected]);

  // Disabled while the WebSocket is connected, during the debounce period,
  // when ids are missing, or when the caller explicitly passed `enabled: false`.
  const isEnabled =
    !!tenantId &&
    !!jobId &&
    !isWebSocketConnected &&
    enablePolling &&
    callerEnabled !== false;

  console.log('🔄 Training status query:', {
    tenantId: !!tenantId,
    jobId: !!jobId,
    isWebSocketConnected,
    enablePolling,
    queryEnabled: isEnabled
  });

  return useQuery<TrainingJobStatus, ApiError>({
    queryKey: trainingKeys.jobs.status(tenantId, jobId),
    queryFn: () => {
      console.log('📡 Executing HTTP training status query (WebSocket disconnected)');
      return trainingService.getTrainingJobStatus(tenantId, jobId);
    },
    enabled: isEnabled,
    // Only install a refetch interval while the query is enabled.
    refetchInterval: isEnabled ? (query) => {
      const data = query.state.data;
      // Stop polling on auth errors.
      if (query.state.error && (query.state.error as any)?.status === 401) {
        console.log('🚫 Stopping status polling due to auth error');
        return false;
      }
      // Stop polling once the job reached a terminal state.
      if (data?.status === 'completed' || data?.status === 'failed') {
        console.log('🏁 Training completed - stopping HTTP polling');
        return false;
      }
      console.log(`📊 HTTP fallback polling active (WebSocket disconnected) - ${HTTP_POLLING_INTERVAL_MS}ms interval`);
      return HTTP_POLLING_INTERVAL_MS;
    } : false,
    staleTime: 1000, // Consider data stale after 1 second
    retry: (failureCount, error) => {
      // Don't retry on auth errors
      if ((error as any)?.status === 401) {
        console.log('🚫 Not retrying due to auth error');
        return false;
      }
      return failureCount < 3;
    },
    ...queryOptions,
  });
};
// Model Queries
/**
 * Queries `trainingService.getActiveModel` for one tenant/product pair.
 *
 * The query is disabled until both ids are non-empty and its data is treated
 * as fresh for five minutes. Values in `options` override these defaults.
 *
 * @param tenantId Tenant id used in the request and the cache key.
 * @param inventoryProductId Inventory product id used in the request and the cache key.
 * @param options Extra query options, spread last.
 * @returns The `useQuery` result.
 */
export const useActiveModel = (
  tenantId: string,
  inventoryProductId: string,
  options?: Omit<UseQueryOptions<any, ApiError>, 'queryKey' | 'queryFn'>
) => {
  const fiveMinutesMs = 5 * 60 * 1000;
  const hasRequiredIds = Boolean(tenantId) && Boolean(inventoryProductId);
  const fetchActiveModel = () => trainingService.getActiveModel(tenantId, inventoryProductId);

  return useQuery<any, ApiError>({
    queryKey: trainingKeys.models.active(tenantId, inventoryProductId),
    queryFn: fetchActiveModel,
    enabled: hasRequiredIds,
    staleTime: fiveMinutesMs,
    ...options,
  });
};
/**
 * Queries `trainingService.getModels` for a tenant.
 *
 * `queryParams` is forwarded to the service and is part of the cache key, so
 * different parameter values are cached separately. The query is disabled
 * while `tenantId` is empty and its data is treated as fresh for two minutes.
 * Values in `options` override these defaults.
 *
 * @param tenantId Tenant id used in the request and the cache key.
 * @param queryParams Optional parameters passed through to the service call.
 * @param options Extra query options, spread last.
 * @returns The `useQuery` result.
 */
export const useModels = (
  tenantId: string,
  queryParams?: any,
  options?: Omit<UseQueryOptions<any, ApiError>, 'queryKey' | 'queryFn'>
) => {
  const twoMinutesMs = 2 * 60 * 1000;
  const fetchModels = () => trainingService.getModels(tenantId, queryParams);

  return useQuery<any, ApiError>({
    queryKey: trainingKeys.models.list(tenantId, queryParams),
    queryFn: fetchModels,
    enabled: Boolean(tenantId),
    staleTime: twoMinutesMs,
    ...options,
  });
};
/**
 * Queries `trainingService.getModelMetrics` for one model.
 *
 * The query is disabled until both ids are non-empty and its data is treated
 * as fresh for ten minutes. Values in `options` override these defaults.
 *
 * @param tenantId Tenant id used in the request and the cache key.
 * @param modelId Model id used in the request and the cache key.
 * @param options Extra query options, spread last.
 * @returns The `useQuery` result typed as `ModelMetricsResponse`.
 */
export const useModelMetrics = (
  tenantId: string,
  modelId: string,
  options?: Omit<UseQueryOptions<ModelMetricsResponse, ApiError>, 'queryKey' | 'queryFn'>
) => {
  const tenMinutesMs = 10 * 60 * 1000;
  const hasRequiredIds = Boolean(tenantId) && Boolean(modelId);
  const fetchMetrics = () => trainingService.getModelMetrics(tenantId, modelId);

  return useQuery<ModelMetricsResponse, ApiError>({
    queryKey: trainingKeys.models.metrics(tenantId, modelId),
    queryFn: fetchMetrics,
    enabled: hasRequiredIds,
    staleTime: tenMinutesMs,
    ...options,
  });
};
/**
 * Queries `trainingService.getModelPerformance` for one model.
 *
 * The query is disabled until both ids are non-empty and its data is treated
 * as fresh for ten minutes. Values in `options` override these defaults.
 *
 * @param tenantId Tenant id used in the request and the cache key.
 * @param modelId Model id used in the request and the cache key.
 * @param options Extra query options, spread last.
 * @returns The `useQuery` result.
 */
export const useModelPerformance = (
  tenantId: string,
  modelId: string,
  options?: Omit<UseQueryOptions<any, ApiError>, 'queryKey' | 'queryFn'>
) => {
  const tenMinutesMs = 10 * 60 * 1000;
  const hasRequiredIds = Boolean(tenantId) && Boolean(modelId);
  const fetchPerformance = () => trainingService.getModelPerformance(tenantId, modelId);

  return useQuery<any, ApiError>({
    queryKey: trainingKeys.models.performance(tenantId, modelId),
    queryFn: fetchPerformance,
    enabled: hasRequiredIds,
    staleTime: tenMinutesMs,
    ...options,
  });
};
// Statistics Queries
/**
 * Queries `trainingService.getTenantStatistics` for a tenant.
 *
 * The query is disabled while `tenantId` is empty and its data is treated as
 * fresh for five minutes. Values in `options` override these defaults.
 *
 * @param tenantId Tenant id used in the request and the cache key.
 * @param options Extra query options, spread last.
 * @returns The `useQuery` result.
 */
export const useTenantTrainingStatistics = (
  tenantId: string,
  options?: Omit<UseQueryOptions<any, ApiError>, 'queryKey' | 'queryFn'>
) => {
  const fiveMinutesMs = 5 * 60 * 1000;
  const fetchStatistics = () => trainingService.getTenantStatistics(tenantId);

  return useQuery<any, ApiError>({
    queryKey: trainingKeys.statistics(tenantId),
    queryFn: fetchStatistics,
    enabled: Boolean(tenantId),
    staleTime: fiveMinutesMs,
    ...options,
  });
};
// Training Job Mutations
/**
 * Mutation that starts a training job via `trainingService.createTrainingJob`.
 *
 * On success the hook seeds the job-status cache entry for the new job (so
 * status consumers have data immediately) and invalidates the tenant's
 * training statistics. A caller-supplied `onSuccess` runs afterwards and no
 * longer replaces this cache maintenance.
 *
 * @param options Extra mutation options. Everything except `onSuccess`
 *   overrides the defaults set here; `onSuccess` is chained.
 * @returns The `useMutation` result; variables are `{ tenantId, request }`.
 */
export const useCreateTrainingJob = (
  options?: UseMutationOptions<
    TrainingJobResponse,
    ApiError,
    { tenantId: string; request: TrainingJobRequest }
  >
) => {
  const queryClient = useQueryClient();
  // Pull out the caller's onSuccess so spreading the rest cannot clobber ours.
  const { onSuccess: callerOnSuccess, ...mutationOptions } = options ?? {};

  return useMutation<
    TrainingJobResponse,
    ApiError,
    { tenantId: string; request: TrainingJobRequest }
  >({
    mutationFn: ({ tenantId, request }) => trainingService.createTrainingJob(tenantId, request),
    onSuccess: (data, variables, ...rest) => {
      const { tenantId } = variables;

      // Add the job status to cache
      queryClient.setQueryData(
        trainingKeys.jobs.status(tenantId, data.job_id),
        {
          job_id: data.job_id,
          status: data.status,
          progress: 0,
        }
      );

      // Invalidate statistics to reflect the new training job
      queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });

      // Hand over to the caller's callback; its result is returned so an
      // async callback is still awaited by the mutation.
      return callerOnSuccess?.(data, variables, ...rest);
    },
    ...mutationOptions,
  });
};
/**
 * Mutation that trains a single product via `trainingService.trainSingleProduct`.
 *
 * On success the hook seeds the job-status cache entry for the new job,
 * invalidates the active-model query of the trained product, and invalidates
 * the tenant's training statistics. A caller-supplied `onSuccess` runs
 * afterwards and no longer replaces this cache maintenance.
 *
 * @param options Extra mutation options. Everything except `onSuccess`
 *   overrides the defaults set here; `onSuccess` is chained.
 * @returns The `useMutation` result; variables are
 *   `{ tenantId, inventoryProductId, request }`.
 */
export const useTrainSingleProduct = (
  options?: UseMutationOptions<
    TrainingJobResponse,
    ApiError,
    { tenantId: string; inventoryProductId: string; request: SingleProductTrainingRequest }
  >
) => {
  const queryClient = useQueryClient();
  // Pull out the caller's onSuccess so spreading the rest cannot clobber ours.
  const { onSuccess: callerOnSuccess, ...mutationOptions } = options ?? {};

  return useMutation<
    TrainingJobResponse,
    ApiError,
    { tenantId: string; inventoryProductId: string; request: SingleProductTrainingRequest }
  >({
    mutationFn: ({ tenantId, inventoryProductId, request }) =>
      trainingService.trainSingleProduct(tenantId, inventoryProductId, request),
    onSuccess: (data, variables, ...rest) => {
      const { tenantId, inventoryProductId } = variables;

      // Add the job status to cache
      queryClient.setQueryData(
        trainingKeys.jobs.status(tenantId, data.job_id),
        {
          job_id: data.job_id,
          status: data.status,
          progress: 0,
        }
      );

      // Invalidate active model for this product
      queryClient.invalidateQueries({
        queryKey: trainingKeys.models.active(tenantId, inventoryProductId)
      });

      // Invalidate statistics
      queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });

      // Hand over to the caller's callback; its result is returned so an
      // async callback is still awaited by the mutation.
      return callerOnSuccess?.(data, variables, ...rest);
    },
    ...mutationOptions,
  });
};
// Admin Mutations
/**
 * Admin mutation that calls `trainingService.deleteAllTenantModels`.
 *
 * On success the hook invalidates every query under `trainingKeys.models.all()`
 * and the tenant's training statistics. A caller-supplied `onSuccess` runs
 * afterwards and no longer replaces these invalidations.
 *
 * @param options Extra mutation options. Everything except `onSuccess`
 *   overrides the defaults set here; `onSuccess` is chained.
 * @returns The `useMutation` result; variables are `{ tenantId }`.
 */
export const useDeleteAllTenantModels = (
  options?: UseMutationOptions<{ message: string }, ApiError, { tenantId: string }>
) => {
  const queryClient = useQueryClient();
  // Pull out the caller's onSuccess so spreading the rest cannot clobber ours.
  const { onSuccess: callerOnSuccess, ...mutationOptions } = options ?? {};

  return useMutation<{ message: string }, ApiError, { tenantId: string }>({
    mutationFn: ({ tenantId }) => trainingService.deleteAllTenantModels(tenantId),
    onSuccess: (data, variables, ...rest) => {
      // The models key is not tenant-scoped, so this invalidates model
      // queries for every tenant present in the cache, not only this one.
      queryClient.invalidateQueries({ queryKey: trainingKeys.models.all() });
      queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(variables.tenantId) });

      // Hand over to the caller's callback; its result is returned so an
      // async callback is still awaited by the mutation.
      return callerOnSuccess?.(data, variables, ...rest);
    },
    ...mutationOptions,
  });
};
// WebSocket Hook for Real-time Training Updates
/** Close reason the heartbeat uses when it detects a different logged-in user. */
const TRAINING_WS_AUTH_CHANGED_REASON = 'Auth session changed - reconnecting';

/**
 * Decodes the payload (second segment) of a JWT.
 * JWT segments are base64url encoded, so the URL-safe alphabet is mapped back
 * to standard base64 and padding is restored before calling `atob`.
 * Throws when the token has no payload segment or the payload is not JSON.
 */
const decodeJwtPayload = (token: string): Record<string, unknown> => {
  const segment = token.split('.')[1];
  if (!segment) {
    throw new Error('Token has no payload segment');
  }
  const base64 = segment.replace(/-/g, '+').replace(/_/g, '/');
  const padded = base64.padEnd(base64.length + ((4 - (base64.length % 4)) % 4), '=');
  return JSON.parse(atob(padded));
};

/**
 * Returns true when two tokens belong to different users (`user_id` or `sub`
 * differs). A refreshed token for the same user returns false. Unparseable
 * tokens are treated as "same session" so a parse problem never forces a
 * reconnect.
 */
const isNewAuthSession = (oldToken: string, newToken: string): boolean => {
  if (!oldToken || !newToken) return !!oldToken !== !!newToken;
  try {
    const oldPayload = decodeJwtPayload(oldToken);
    const newPayload = decodeJwtPayload(newToken);
    return oldPayload.user_id !== newPayload.user_id ||
      oldPayload.sub !== newPayload.sub;
  } catch (e) {
    console.warn('Failed to parse token for session comparison:', e);
    return false;
  }
};

/** Returns `eventData.progress` when it is a number, otherwise `undefined`. */
const readNumericProgress = (eventData: any): number | undefined =>
  typeof eventData?.progress === 'number' ? eventData.progress : undefined;

/**
 * Progress for a `product_completed` event:
 * PROGRESS_DATA_ANALYSIS + floor(completed / total * (PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS)).
 * Missing counts default to 0 completed out of 1 total.
 */
const progressFromProductCounts = (eventData: any): number => {
  const productsCompleted = eventData.products_completed || 0;
  const totalProducts = eventData.total_products || 1;
  const trainingRangeWidth = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS;
  return PROGRESS_DATA_ANALYSIS + Math.floor((productsCompleted / totalProducts) * trainingRangeWidth);
};

/** Logs a human-readable line for the close code of a training WebSocket. */
const logTrainingSocketCloseCode = (code: number, reason: string): void => {
  switch (code) {
    case 1000:
      if (reason === TRAINING_WS_AUTH_CHANGED_REASON) {
        console.log('🔄 WebSocket closed for auth session change - will reconnect immediately');
      } else {
        console.log('🔒 WebSocket closed normally');
      }
      break;
    case 1006:
      console.log('⚠️ WebSocket closed abnormally (1006) - likely server-side issue or network problem');
      break;
    case 1001:
      console.log('🔄 WebSocket endpoint going away');
      break;
    case 1003:
      console.log('❌ WebSocket unsupported data received');
      break;
    default:
      console.log(`❓ WebSocket closed with code ${code}`);
  }
};

/**
 * WebSocket hook for real-time training updates.
 *
 * Opens a socket through `trainingService.createWebSocketConnection`, mirrors
 * incoming events into the job-status query cache, forwards them to the
 * callbacks in `options`, sends a periodic `ping`, and reconnects with
 * exponential backoff (up to `WEBSOCKET_MAX_RECONNECT_ATTEMPTS`) after
 * non-1000 closes.
 *
 * Changes compared with the previous implementation:
 * - Callbacks are read through a ref, so new callback identities (inline
 *   arrow functions) no longer tear down and reopen the socket.
 * - The heartbeat timer is bound to its own socket and cleared on close and
 *   on cleanup, so heartbeats cannot pile up across reconnects.
 * - `connect()` aborts if the effect was cleaned up while the token was being
 *   fetched, so no socket is opened after unmount.
 * - Events that carry no numeric `progress` keep the cached progress instead
 *   of resetting it to 0.
 * - JWT payloads are decoded as base64url.
 *
 * @param tenantId Tenant id; no connection is made while empty.
 * @param jobId Training job id; no connection is made while empty.
 * @param token Unused. The token always comes from
 *   `apiClient.ensureValidToken()`; the parameter is kept for compatibility.
 * @param options Event callbacks. NOTE(review): as before, no connection is
 *   made when `options` is omitted entirely — confirm whether callers rely on
 *   that as an "off switch" before relaxing it.
 * @returns `{ isConnected, connectionError }`.
 */
export const useTrainingWebSocket = (
  tenantId: string,
  jobId: string,
  token?: string,
  options?: {
    onProgress?: (data: any) => void;
    onCompleted?: (data: any) => void;
    onError?: (error: any) => void;
    onStarted?: (data: any) => void;
    onCancelled?: (data: any) => void;
  }
) => {
  const queryClient = useQueryClient();
  const [isConnected, setIsConnected] = React.useState(false);
  const [connectionError, setConnectionError] = React.useState<string | null>(null);

  // Always holds the latest callbacks without making them effect dependencies.
  const callbacksRef = React.useRef(options);
  React.useEffect(() => {
    callbacksRef.current = options;
  });

  // Preserves the previous gate: the socket is only opened when options exist.
  const hasOptions = !!options;

  React.useEffect(() => {
    if (!tenantId || !jobId || !hasOptions) {
      return undefined;
    }

    let ws: WebSocket | null = null;
    let reconnectTimer: ReturnType<typeof setTimeout> | null = null;
    let heartbeatTimer: ReturnType<typeof setInterval> | null = null;
    // Set only by the cleanup function; handlers of a disposed effect run must
    // not touch React state, because a newer run may already own it.
    let disposed = false;
    // Set by cleanup and by terminal events (completed/failed) to stop reconnects.
    let isManuallyDisconnected = false;
    let reconnectAttempts = 0;
    const maxReconnectAttempts = WEBSOCKET_MAX_RECONNECT_ATTEMPTS;

    const stopHeartbeat = () => {
      if (heartbeatTimer !== null) {
        clearInterval(heartbeatTimer);
        heartbeatTimer = null;
      }
    };

    const connect = async () => {
      if (disposed || isManuallyDisconnected) {
        return;
      }
      try {
        setConnectionError(null);

        // Use centralized token management from apiClient
        let effectiveToken: string;
        try {
          const freshToken = await apiClient.ensureValidToken();
          if (!freshToken) {
            throw new Error('No valid token available');
          }
          effectiveToken = freshToken;
        } catch (error) {
          console.error('❌ Failed to get valid token for WebSocket:', error);
          if (!disposed) {
            setConnectionError('Authentication failed. Please log in again.');
          }
          return;
        }

        // The effect may have been cleaned up while the token was in flight.
        if (disposed || isManuallyDisconnected) {
          return;
        }

        console.log(`🔄 Attempting WebSocket connection (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts + 1}):`, {
          tenantId,
          jobId,
          hasToken: !!effectiveToken,
          tokenFromApiClient: true
        });

        // Handlers close over `socket`, not the shared `ws`, so each handler
        // always talks about the connection it was registered on.
        const socket = trainingService.createWebSocketConnection(tenantId, jobId, effectiveToken);
        ws = socket;
        let connectedAt: number | null = null;

        socket.onopen = () => {
          if (disposed) {
            return;
          }
          console.log('✅ Training WebSocket connected successfully', {
            readyState: socket.readyState,
            url: socket.url,
            jobId
          });
          connectedAt = Date.now();
          setIsConnected(true);
          reconnectAttempts = 0; // Reset on successful connection

          // Request current status on connection
          try {
            socket.send('get_status');
            console.log('📤 Requested current training status');
          } catch (e) {
            console.warn('Failed to request status on connection:', e);
          }

          // Periodic ping plus a check for auth session changes.
          stopHeartbeat();
          const thisHeartbeat: ReturnType<typeof setInterval> = setInterval(async () => {
            const stopSelf = () => {
              clearInterval(thisHeartbeat);
              if (heartbeatTimer === thisHeartbeat) {
                heartbeatTimer = null;
              }
            };

            if (socket.readyState !== WebSocket.OPEN || isManuallyDisconnected) {
              stopSelf();
              return;
            }

            try {
              // Check token validity (this may refresh if needed)
              const currentToken = await apiClient.ensureValidToken();

              // The socket may have closed while the token was being checked.
              if (socket.readyState !== WebSocket.OPEN || isManuallyDisconnected) {
                stopSelf();
                return;
              }

              // Only reconnect if user changed (new auth session)
              if (currentToken && effectiveToken && isNewAuthSession(effectiveToken, currentToken)) {
                console.log('🔄 Auth session changed (different user) - reconnecting WebSocket');
                socket.close(1000, TRAINING_WS_AUTH_CHANGED_REASON);
                stopSelf();
                return;
              }

              // Token may have been refreshed but it's the same user - continue
              if (currentToken && currentToken !== effectiveToken) {
                console.log(' Token refreshed (same user) - updating reference');
                effectiveToken = currentToken;
              }

              socket.send('ping');
              console.log('💓 Sent ping to server');
            } catch (e) {
              console.warn('Failed to send ping or validate token:', e);
              stopSelf();
            }
          }, WEBSOCKET_HEARTBEAT_INTERVAL_MS);
          heartbeatTimer = thisHeartbeat;
        };

        socket.onmessage = (event) => {
          if (disposed) {
            return;
          }
          try {
            // Handle non-JSON messages (like pong responses)
            if (typeof event.data === 'string' && event.data === 'pong') {
              console.log('🏓 Pong received from server');
              return;
            }

            const message = JSON.parse(event.data);
            console.log('🔔 Training WebSocket message received:', message);

            // Initial state message: restore the latest known state only.
            if (message.type === 'initial_state') {
              console.log('📥 Received initial state:', message.data);
              const initialData = message.data || {};
              const initialEventData = initialData.data || {};
              const isInitialProductCompleted = initialData.type === 'product_completed';
              const initialProgress = isInitialProductCompleted
                ? progressFromProductCounts(initialEventData)
                : readNumericProgress(initialEventData);

              if (isInitialProductCompleted) {
                console.log('📦 Product training completed in initial state',
                  `${initialEventData.products_completed || 0}/${initialEventData.total_products || 1}`,
                  `progress: ${initialProgress}%`);
              }

              queryClient.setQueryData(
                trainingKeys.jobs.status(tenantId, jobId),
                (oldData: TrainingJobStatus | undefined) => ({
                  ...oldData,
                  job_id: jobId,
                  status: initialData.type === 'completed' ? 'completed' :
                    initialData.type === 'failed' ? 'failed' :
                    initialData.type === 'started' ? 'running' :
                    initialData.type === 'progress' ? 'running' :
                    initialData.type === 'product_completed' ? 'running' :
                    initialData.type === 'step_completed' ? 'running' :
                    oldData?.status || 'running',
                  // Keep the cached progress when the event carries none.
                  progress: initialProgress ?? oldData?.progress ?? 0,
                  current_step: initialEventData.current_step || initialEventData.step_name || oldData?.current_step,
                })
              );
              return; // Not processed as a regular event
            }

            // Extract data from backend message structure
            const eventData = message.data || {};
            const isProductCompleted = message.type === 'product_completed';
            const progress = isProductCompleted
              ? progressFromProductCounts(eventData)
              : readNumericProgress(eventData);
            const currentStep = eventData.current_step || eventData.step_name || '';

            if (isProductCompleted) {
              console.log('📦 Product training completed',
                `${eventData.products_completed || 0}/${eventData.total_products || 1}`,
                `progress: ${progress}%`);
            }

            // Update job status in cache
            queryClient.setQueryData(
              trainingKeys.jobs.status(tenantId, jobId),
              (oldData: TrainingJobStatus | undefined) => ({
                ...oldData,
                job_id: jobId,
                status: message.type === 'completed' ? 'completed' :
                  message.type === 'failed' ? 'failed' :
                  message.type === 'started' ? 'running' :
                  oldData?.status || 'running',
                // Keep the cached progress when the event carries none.
                progress: progress ?? oldData?.progress ?? 0,
                current_step: currentStep || oldData?.current_step,
              })
            );

            // Call appropriate callback based on message type
            const callbacks = callbacksRef.current;
            switch (message.type) {
              case 'connected':
                console.log('🔗 WebSocket connected');
                break;
              case 'started':
                console.log('🚀 Training started');
                callbacks?.onStarted?.(message);
                break;
              case 'progress':
                console.log('📊 Training progress update', `${progress ?? 0}%`);
                callbacks?.onProgress?.(message);
                break;
              case 'product_completed':
                console.log('✅ Product training completed');
                // Treat as progress update, carrying the calculated progress
                callbacks?.onProgress?.({
                  ...message,
                  data: {
                    ...eventData,
                    progress,
                  }
                });
                break;
              case 'step_completed':
                console.log('📋 Step completed');
                callbacks?.onProgress?.(message);
                break;
              case 'completed':
                console.log('✅ Training completed successfully');
                callbacks?.onCompleted?.(message);
                // Invalidate models and statistics
                queryClient.invalidateQueries({ queryKey: trainingKeys.models.all() });
                queryClient.invalidateQueries({ queryKey: trainingKeys.statistics(tenantId) });
                isManuallyDisconnected = true;
                break;
              case 'failed':
                console.log('❌ Training failed');
                callbacks?.onError?.(message);
                isManuallyDisconnected = true;
                break;
              default:
                console.log(`🔍 Unknown message type: ${message.type}`);
                break;
            }
          } catch (error) {
            // Also reached when a caller callback throws.
            console.error('Error parsing WebSocket message:', error);
            setConnectionError('Error parsing message from server');
          }
        };

        socket.onerror = (error) => {
          console.error('Training WebSocket error:', error);
          if (disposed) {
            return;
          }
          setConnectionError('WebSocket connection error');
          setIsConnected(false);
        };

        socket.onclose = (event) => {
          stopHeartbeat();
          console.log(`❌ Training WebSocket disconnected. Code: ${event.code}, Reason: "${event.reason}"`, {
            wasClean: event.wasClean,
            jobId,
            timeConnected: `${connectedAt === null ? 0 : Date.now() - connectedAt}ms`,
            reconnectAttempts
          });
          logTrainingSocketCloseCode(event.code, event.reason);

          // A disposed run must not overwrite state owned by a newer run.
          if (disposed) {
            return;
          }
          setIsConnected(false);

          // Terminal event already seen: nothing to reconnect for.
          if (isManuallyDisconnected) {
            return;
          }

          // Auth session change: reconnect after a short delay, without
          // counting it as a failed attempt.
          if (event.code === 1000 && event.reason === TRAINING_WS_AUTH_CHANGED_REASON) {
            console.log('🔄 Reconnecting immediately due to auth session change...');
            reconnectTimer = setTimeout(() => {
              void connect();
            }, WEBSOCKET_RECONNECT_INITIAL_DELAY_MS);
            return;
          }

          if (event.code !== 1000 && reconnectAttempts < maxReconnectAttempts) {
            // Exponential backoff, capped at WEBSOCKET_RECONNECT_MAX_DELAY_MS
            const delay = Math.min(
              WEBSOCKET_RECONNECT_INITIAL_DELAY_MS * Math.pow(2, reconnectAttempts),
              WEBSOCKET_RECONNECT_MAX_DELAY_MS
            );
            console.log(`🔄 Attempting to reconnect WebSocket in ${delay/1000}s... (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})`);
            reconnectTimer = setTimeout(() => {
              reconnectAttempts++;
              void connect();
            }, delay);
          } else if (reconnectAttempts >= maxReconnectAttempts) {
            console.log(`❌ Max reconnection attempts (${maxReconnectAttempts}) reached. Giving up.`);
            setConnectionError(`Connection failed after ${maxReconnectAttempts} attempts. The training job may not exist or the server may be unavailable.`);
          }
        };
      } catch (error) {
        console.error('Error creating WebSocket connection:', error);
        if (!disposed) {
          setConnectionError('Failed to create WebSocket connection');
        }
      }
    };

    // Connect immediately to avoid missing early progress updates
    console.log('🚀 Starting immediate WebSocket connection...');
    void connect();

    // Cleanup function
    return () => {
      disposed = true;
      isManuallyDisconnected = true;
      if (reconnectTimer) {
        clearTimeout(reconnectTimer);
      }
      stopHeartbeat();
      if (ws) {
        ws.close(1000, 'Component unmounted');
      }
      setIsConnected(false);
    };
  }, [tenantId, jobId, queryClient, hasOptions]);

  return {
    isConnected,
    connectionError
  };
};
// Utility Hooks
/**
 * Reports whether the given training job is still in flight.
 *
 * Reads the job status through `useTrainingJobStatus` (passing `enabled` only
 * when a job id exists) and returns true while that status is `'running'` or
 * `'pending'`.
 *
 * @param tenantId Tenant that owns the job.
 * @param jobId Optional job id; an empty string is used when it is missing.
 * @param isWebSocketConnected Forwarded to `useTrainingJobStatus`.
 * @returns True when the cached status is `'running'` or `'pending'`.
 */
export const useIsTrainingInProgress = (
  tenantId: string,
  jobId?: string,
  isWebSocketConnected?: boolean
) => {
  const statusQuery = useTrainingJobStatus(tenantId, jobId || '', {
    enabled: !!jobId,
    isWebSocketConnected,
  });
  const currentStatus = statusQuery.data?.status;

  if (currentStatus === 'running') {
    return true;
  }
  return currentStatus === 'pending';
};
/**
 * Derives a small progress summary for a training job.
 *
 * Reads the job status through `useTrainingJobStatus` (passing `enabled` only
 * when a job id exists) and projects it into display-friendly fields.
 *
 * @param tenantId Tenant that owns the job.
 * @param jobId Optional job id; an empty string is used when it is missing.
 * @param isWebSocketConnected Forwarded to `useTrainingJobStatus`.
 * @returns `progress` (0 when unknown or falsy), `currentStep`, and the
 *   `isComplete` / `isFailed` / `isRunning` flags derived from the status.
 */
export const useTrainingProgress = (
  tenantId: string,
  jobId?: string,
  isWebSocketConnected?: boolean
) => {
  const statusQuery = useTrainingJobStatus(tenantId, jobId || '', {
    enabled: !!jobId,
    isWebSocketConnected,
  });
  const snapshot = statusQuery.data;
  const currentStatus = snapshot?.status;

  return {
    progress: snapshot?.progress || 0,
    currentStep: snapshot?.current_step,
    isComplete: currentStatus === 'completed',
    isFailed: currentStatus === 'failed',
    isRunning: currentStatus === 'running',
  };
};