diff --git a/frontend/src/api/client/apiClient.ts b/frontend/src/api/client/apiClient.ts index d99161b0..39c55d49 100644 --- a/frontend/src/api/client/apiClient.ts +++ b/frontend/src/api/client/apiClient.ts @@ -91,7 +91,27 @@ class ApiClient { // Response interceptor for error handling and automatic token refresh this.client.interceptors.response.use( - (response) => response, + (response) => { + // Enhanced logging for token refresh header detection + const refreshSuggested = response.headers['x-token-refresh-suggested']; + if (refreshSuggested) { + console.log('πŸ” TOKEN REFRESH HEADER DETECTED:', { + url: response.config?.url, + method: response.config?.method, + status: response.status, + refreshSuggested, + hasRefreshToken: !!this.refreshToken, + currentTokenLength: this.authToken?.length || 0 + }); + } + + // Check if server suggests token refresh + if (refreshSuggested === 'true' && this.refreshToken) { + console.log('πŸ”„ Server suggests token refresh - refreshing proactively'); + this.proactiveTokenRefresh(); + } + return response; + }, async (error) => { const originalRequest = error.config; @@ -228,6 +248,40 @@ class ApiClient { } } + private async proactiveTokenRefresh() { + // Avoid multiple simultaneous proactive refreshes + if (this.isRefreshing) { + return; + } + + try { + this.isRefreshing = true; + console.log('πŸ”„ Proactively refreshing token...'); + + const response = await this.client.post('/auth/refresh', { + refresh_token: this.refreshToken + }); + + const { access_token, refresh_token } = response.data; + + // Update tokens + this.setAuthToken(access_token); + if (refresh_token) { + this.setRefreshToken(refresh_token); + } + + // Update auth store + await this.updateAuthStore(access_token, refresh_token); + + console.log('βœ… Proactive token refresh successful'); + } catch (error) { + console.warn('⚠️ Proactive token refresh failed:', error); + // Don't handle as auth failure here - let the next 401 handle it + } finally { + this.isRefreshing = false; + } + } + private async handleAuthFailure() { try { // Clear tokens @@ -275,6 +329,117 @@ class ApiClient { return this.tenantId; } + // Token synchronization methods for WebSocket connections + getCurrentValidToken(): string | null { + return this.authToken; + } + + async ensureValidToken(): Promise { + const originalToken = this.authToken; + const originalTokenShort = originalToken ? `${originalToken.slice(0, 20)}...${originalToken.slice(-10)}` : 'null'; + + console.log('πŸ” ensureValidToken() called:', { + hasToken: !!this.authToken, + tokenPreview: originalTokenShort, + isRefreshing: this.isRefreshing, + hasRefreshToken: !!this.refreshToken + }); + + // If we have a valid token, return it + if (this.authToken && !this.isTokenNearExpiry(this.authToken)) { + const expiryInfo = this.getTokenExpiryInfo(this.authToken); + console.log('βœ… Token is valid, returning current token:', { + tokenPreview: originalTokenShort, + expiryInfo + }); + return this.authToken; + } + + // If token is near expiry or expired, try to refresh + if (this.refreshToken && !this.isRefreshing) { + console.log('πŸ”„ Token needs refresh, attempting proactive refresh:', { + reason: this.authToken ? 'near expiry' : 'no token', + expiryInfo: this.authToken ? this.getTokenExpiryInfo(this.authToken) : 'N/A' + }); + + try { + await this.proactiveTokenRefresh(); + const newTokenShort = this.authToken ? `${this.authToken.slice(0, 20)}...${this.authToken.slice(-10)}` : 'null'; + const tokenChanged = originalToken !== this.authToken; + + console.log('βœ… Token refresh completed:', { + tokenChanged, + oldTokenPreview: originalTokenShort, + newTokenPreview: newTokenShort, + newExpiryInfo: this.authToken ? this.getTokenExpiryInfo(this.authToken) : 'N/A' + }); + + return this.authToken; + } catch (error) { + console.warn('❌ Failed to refresh token in ensureValidToken:', error); + return null; + } + } + + console.log('⚠️ Returning current token without refresh:', { + reason: this.isRefreshing ? 'already refreshing' : 'no refresh token', + tokenPreview: originalTokenShort + }); + return this.authToken; + } + + private getTokenExpiryInfo(token: string): any { + try { + const payload = JSON.parse(atob(token.split('.')[1])); + const exp = payload.exp; + const iat = payload.iat; + if (!exp) return { error: 'No expiry in token' }; + + const now = Math.floor(Date.now() / 1000); + const timeUntilExpiry = exp - now; + const tokenLifetime = exp - iat; + + return { + issuedAt: new Date(iat * 1000).toISOString(), + expiresAt: new Date(exp * 1000).toISOString(), + lifetimeMinutes: Math.floor(tokenLifetime / 60), + secondsUntilExpiry: timeUntilExpiry, + minutesUntilExpiry: Math.floor(timeUntilExpiry / 60), + isNearExpiry: timeUntilExpiry < 300, + isExpired: timeUntilExpiry <= 0 + }; + } catch (error) { + return { error: 'Failed to parse token', details: error }; + } + } + + private isTokenNearExpiry(token: string): boolean { + try { + const payload = JSON.parse(atob(token.split('.')[1])); + const exp = payload.exp; + if (!exp) return false; + + const now = Math.floor(Date.now() / 1000); + const timeUntilExpiry = exp - now; + + // Consider token near expiry if less than 5 minutes remaining + const isNear = timeUntilExpiry < 300; + + if (isNear) { + console.log('⏰ Token is near expiry:', { + secondsUntilExpiry: timeUntilExpiry, + minutesUntilExpiry: Math.floor(timeUntilExpiry / 60), + expiresAt: new Date(exp * 1000).toISOString() + }); + } + + return isNear; + } catch (error) { + console.warn('Failed to parse token for expiry check:', error); + return true; // Assume expired if we can't parse + } + } + // HTTP Methods - Return direct data for React Query async get(url: string, config?: AxiosRequestConfig): Promise { const response: AxiosResponse = await this.client.get(url, config); diff --git a/frontend/src/api/hooks/onboarding.ts b/frontend/src/api/hooks/onboarding.ts index 10328140..e04086b7 100644 --- a/frontend/src/api/hooks/onboarding.ts +++ b/frontend/src/api/hooks/onboarding.ts @@ -90,24 +90,35 @@ export const useUpdateStep = ( export const useMarkStepCompleted = ( options?: UseMutationOptions< - UserProgress, - ApiError, + UserProgress, + ApiError, { userId: string; stepName: string; data?: Record } > ) => { const queryClient = useQueryClient(); - + return useMutation< - UserProgress, - ApiError, + UserProgress, + ApiError, { userId: string; stepName: string; data?: Record } >({ - mutationFn: ({ userId, stepName, data }) => + mutationFn: ({ userId, stepName, data }) => onboardingService.markStepCompleted(userId, stepName, data), onSuccess: (data, { userId }) => { - // Update progress cache + // Update progress cache with new data queryClient.setQueryData(onboardingKeys.progress(userId), data); + + // Invalidate the query to ensure fresh data on next access + queryClient.invalidateQueries({ queryKey: onboardingKeys.progress(userId) }); }, + onError: (error, { userId, stepName }) => { + console.error(`Failed to complete step ${stepName} for user ${userId}:`, error); + + // Invalidate queries on error to ensure we get fresh data + queryClient.invalidateQueries({ queryKey: onboardingKeys.progress(userId) }); + }, + // Prevent duplicate requests by using the step name as a mutation key + mutationKey: (variables) => ['markStepCompleted', variables?.userId, variables?.stepName], ...options, }); }; diff --git a/frontend/src/api/hooks/training.ts b/frontend/src/api/hooks/training.ts index 02111cd0..9a8176c7 100644 --- a/frontend/src/api/hooks/training.ts +++ b/frontend/src/api/hooks/training.ts @@ -3,10 +3,10 @@ * Provides data fetching, caching, and state management for training operations */ -import React from 'react'; +import * as React from 'react'; import { useMutation, useQuery, useQueryClient, UseQueryOptions, UseMutationOptions } from '@tanstack/react-query'; import { trainingService } from '../services/training'; -import { ApiError } from '../client/apiClient'; +import { ApiError, apiClient } from '../client/apiClient'; import { useAuthStore } from '../../stores/auth.store'; import type { TrainingJobRequest, @@ -53,15 +53,62 @@ export const trainingKeys = { export const useTrainingJobStatus = ( tenantId: string, jobId: string, - options?: Omit, 'queryKey' | 'queryFn'> + options?: Omit, 'queryKey' | 'queryFn'> & { + isWebSocketConnected?: boolean; + } ) => { + const { isWebSocketConnected, ...queryOptions } = options || {}; + + // Completely disable the query when WebSocket is connected + const isEnabled = !!tenantId && !!jobId && !isWebSocketConnected; + + console.log('πŸ”„ Training status query:', { + tenantId: !!tenantId, + jobId: !!jobId, + isWebSocketConnected, + queryEnabled: isEnabled + }); + return useQuery({ queryKey: trainingKeys.jobs.status(tenantId, jobId), - queryFn: () => trainingService.getTrainingJobStatus(tenantId, jobId), - enabled: !!tenantId && !!jobId, - refetchInterval: 5000, // Poll every 5 seconds while training + queryFn: () => { + console.log('πŸ“‘ Executing HTTP training status query (WebSocket disconnected)'); + return trainingService.getTrainingJobStatus(tenantId, jobId); + }, + enabled: isEnabled, // Completely disable when WebSocket connected + refetchInterval: (query) => { + // CRITICAL FIX: React Query executes refetchInterval even when enabled=false + // We must check WebSocket connection state here to prevent misleading polling + if (isWebSocketConnected) { + console.log('βœ… WebSocket connected - HTTP polling DISABLED'); + return false; // Disable polling when WebSocket is active + } + + const data = query.state.data; + + // Stop polling if we get auth errors or training is completed + if (query.state.error && (query.state.error as any)?.status === 401) { + console.log('🚫 Stopping status polling due to auth error'); + return false; + } + if (data?.status === 'completed' || data?.status === 'failed') { + console.log('🏁 Training completed - stopping HTTP polling'); + return false; // Stop polling when training is done + } + + console.log('πŸ“Š HTTP fallback polling active (WebSocket actually disconnected) - 5s interval'); + return 5000; // Poll every 5 seconds while training (fallback when WebSocket unavailable) + }, staleTime: 1000, // Consider data stale after 1 second - ...options, + retry: (failureCount, error) => { + // Don't retry on auth errors + if ((error as any)?.status === 401) { + console.log('🚫 Not retrying due to auth error'); + return false; + } + return failureCount < 3; + }, + ...queryOptions, }); }; @@ -242,9 +289,9 @@ export const useTrainingWebSocket = ( } ) => { const queryClient = useQueryClient(); - const authToken = useAuthStore((state) => state.token); const [isConnected, setIsConnected] = React.useState(false); const [connectionError, setConnectionError] = React.useState(null); + const [connectionAttempts, setConnectionAttempts] = React.useState(0); // Memoize options to prevent unnecessary effect re-runs const memoizedOptions = React.useMemo(() => options, [ @@ -266,20 +313,44 @@ export const useTrainingWebSocket = ( let reconnectAttempts = 0; const maxReconnectAttempts = 3; - const connect = () => { + const connect = async () => { try { setConnectionError(null); - const effectiveToken = token || authToken; + setConnectionAttempts(prev => prev + 1); + + // Use centralized token management from apiClient + let effectiveToken: string | null; + + try { + // Always use the apiClient's token management + effectiveToken = await apiClient.ensureValidToken(); + + if (!effectiveToken) { + throw new Error('No valid token available'); + } + } catch (error) { + console.error('❌ Failed to get valid token for WebSocket:', error); + setConnectionError('Authentication failed. Please log in again.'); + return; + } + console.log(`πŸ”„ Attempting WebSocket connection (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts + 1}):`, { tenantId, jobId, - hasToken: !!effectiveToken + hasToken: !!effectiveToken, + tokenFromApiClient: true }); - - ws = trainingService.createWebSocketConnection(tenantId, jobId, token || authToken || undefined); + + ws = trainingService.createWebSocketConnection(tenantId, jobId, effectiveToken); ws.onopen = () => { - console.log('βœ… Training WebSocket connected successfully'); + console.log('βœ… Training WebSocket connected successfully', { + readyState: ws?.readyState, + url: ws?.url, + jobId + }); + // Track connection time for debugging + (ws as any)._connectTime = Date.now(); setIsConnected(true); reconnectAttempts = 0; // Reset on successful connection @@ -291,23 +362,81 @@ export const useTrainingWebSocket = ( console.warn('Failed to request status on connection:', e); } - // Set up periodic ping to keep connection alive - const pingInterval = setInterval(() => { + // Helper function to check if tokens represent different auth sessions + const isTokenSessionDifferent = (oldToken: string, newToken: string): boolean => { + if (!oldToken || !newToken) return !!oldToken !== !!newToken; + + try { + const oldPayload = JSON.parse(atob(oldToken.split('.')[1])); + const newPayload = JSON.parse(atob(newToken.split('.')[1])); + + // Compare by issued timestamp (iat) - different iat means new auth session + return oldPayload.iat !== newPayload.iat; + } catch (e) { + console.warn('Failed to parse token for session comparison, falling back to string comparison:', e); + return oldToken !== newToken; + } + }; + + // Set up periodic ping and intelligent token refresh detection + const heartbeatInterval = setInterval(async () => { if (ws?.readyState === WebSocket.OPEN && !isManuallyDisconnected) { try { + // Check token validity (this may refresh if needed) + const currentToken = await apiClient.ensureValidToken(); + + // Enhanced token change detection with detailed logging + const tokenStringChanged = currentToken !== effectiveToken; + const tokenSessionChanged = currentToken && effectiveToken ? + isTokenSessionDifferent(effectiveToken, currentToken) : tokenStringChanged; + + console.log('πŸ” WebSocket token validation check:', { + hasCurrentToken: !!currentToken, + hasEffectiveToken: !!effectiveToken, + tokenStringChanged, + tokenSessionChanged, + currentTokenPreview: currentToken ? `${currentToken.slice(0, 20)}...${currentToken.slice(-10)}` : 'null', + effectiveTokenPreview: effectiveToken ? `${effectiveToken.slice(0, 20)}...${effectiveToken.slice(-10)}` : 'null' + }); + + // Only reconnect if we have a genuine session change (different iat) + if (tokenSessionChanged) { + console.log('πŸ”„ Token session changed - reconnecting WebSocket with new session token'); + console.log('πŸ“Š Session change details:', { + reason: !currentToken ? 'token removed' : + !effectiveToken ? 'token added' : 'new auth session', + oldTokenIat: effectiveToken ? (() => { + try { return JSON.parse(atob(effectiveToken.split('.')[1])).iat; } catch { return 'parse-error'; } + })() : 'N/A', + newTokenIat: currentToken ? (() => { + try { return JSON.parse(atob(currentToken.split('.')[1])).iat; } catch { return 'parse-error'; } + })() : 'N/A' + }); + + // Close current connection and trigger reconnection with new token + ws?.close(1000, 'Token session changed - reconnecting'); + clearInterval(heartbeatInterval); + return; + } else if (tokenStringChanged) { + console.log('ℹ️ Token string changed but same session - continuing with current connection'); + // Update effective token reference for future comparisons + effectiveToken = currentToken; + } + + console.log('βœ… Token validated during heartbeat - same session'); ws?.send('ping'); - console.log('πŸ’“ Sent ping to server'); + console.log('πŸ’“ Sent ping to server (token session validated)'); } catch (e) { - console.warn('Failed to send ping:', e); - clearInterval(pingInterval); + console.warn('Failed to send ping or validate token:', e); + clearInterval(heartbeatInterval); } } else { - clearInterval(pingInterval); + clearInterval(heartbeatInterval); } - }, 30000); // Ping every 30 seconds + }, 30000); // Check every 30 seconds for token refresh and send ping // Store interval for cleanup - (ws as any).pingInterval = pingInterval; + (ws as any).heartbeatInterval = heartbeatInterval; }; ws.onmessage = (event) => { @@ -404,13 +533,22 @@ export const useTrainingWebSocket = ( }; ws.onclose = (event) => { - console.log(`❌ Training WebSocket disconnected. Code: ${event.code}, Reason: "${event.reason}"`); + console.log(`❌ Training WebSocket disconnected. Code: ${event.code}, Reason: "${event.reason}"`, { + wasClean: event.wasClean, + jobId, + timeConnected: ws ? `${Date.now() - (ws as any)._connectTime || 0}ms` : 'unknown', + reconnectAttempts + }); setIsConnected(false); // Detailed logging for different close codes switch (event.code) { case 1000: - console.log('πŸ”’ WebSocket closed normally'); + if (event.reason === 'Token refreshed - reconnecting') { + console.log('πŸ”„ WebSocket closed for token refresh - will reconnect immediately'); + } else { + console.log('πŸ”’ WebSocket closed normally'); + } break; case 1006: console.log('⚠️ WebSocket closed abnormally (1006) - likely server-side issue or network problem'); @@ -425,11 +563,20 @@ export const useTrainingWebSocket = ( console.log(`❓ WebSocket closed with code ${event.code}`); } + // Handle token refresh reconnection (immediate reconnect) + if (event.code === 1000 && event.reason === 'Token refreshed - reconnecting') { + console.log('πŸ”„ Reconnecting immediately due to token refresh...'); + reconnectTimer = setTimeout(() => { + connect(); // Reconnect immediately with fresh token + }, 1000); // Short delay to allow cleanup + return; + } + // Try to reconnect if not manually disconnected and haven't exceeded max attempts if (!isManuallyDisconnected && event.code !== 1000 && reconnectAttempts < maxReconnectAttempts) { const delay = Math.min(1000 * Math.pow(2, reconnectAttempts), 10000); // Exponential backoff, max 10s console.log(`πŸ”„ Attempting to reconnect WebSocket in ${delay/1000}s... (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})`); - + reconnectTimer = setTimeout(() => { reconnectAttempts++; connect(); @@ -470,7 +617,7 @@ export const useTrainingWebSocket = ( setIsConnected(false); }; - }, [tenantId, jobId, token, authToken, queryClient, memoizedOptions]); + }, [tenantId, jobId, queryClient, memoizedOptions]); return { isConnected, @@ -479,17 +626,27 @@ export const useTrainingWebSocket = ( }; // Utility Hooks -export const useIsTrainingInProgress = (tenantId: string, jobId?: string) => { +export const useIsTrainingInProgress = ( + tenantId: string, + jobId?: string, + isWebSocketConnected?: boolean +) => { const { data: jobStatus } = useTrainingJobStatus(tenantId, jobId || '', { enabled: !!jobId, + isWebSocketConnected, }); return jobStatus?.status === 'running' || jobStatus?.status === 'pending'; }; -export const useTrainingProgress = (tenantId: string, jobId?: string) => { +export const useTrainingProgress = ( + tenantId: string, + jobId?: string, + isWebSocketConnected?: boolean +) => { const { data: jobStatus } = useTrainingJobStatus(tenantId, jobId || '', { enabled: !!jobId, + isWebSocketConnected, }); return { diff --git a/frontend/src/components/domain/onboarding/OnboardingWizard.tsx b/frontend/src/components/domain/onboarding/OnboardingWizard.tsx index 6e09afc2..090424be 100644 --- a/frontend/src/components/domain/onboarding/OnboardingWizard.tsx +++ b/frontend/src/components/domain/onboarding/OnboardingWizard.tsx @@ -94,18 +94,21 @@ export const OnboardingWizard: React.FC = () => { const { setCurrentTenant } = useTenantActions(); // Auto-complete user_registered step if needed (runs first) + const [autoCompletionAttempted, setAutoCompletionAttempted] = React.useState(false); + useEffect(() => { - if (userProgress && user?.id) { + if (userProgress && user?.id && !autoCompletionAttempted && !markStepCompleted.isPending) { const userRegisteredStep = userProgress.steps.find(s => s.step_name === 'user_registered'); - + if (!userRegisteredStep?.completed) { console.log('πŸ”„ Auto-completing user_registered step for new user...'); - + setAutoCompletionAttempted(true); + markStepCompleted.mutate({ userId: user.id, stepName: 'user_registered', - data: { - auto_completed: true, + data: { + auto_completed: true, completed_at: new Date().toISOString(), source: 'onboarding_wizard_auto_completion' } @@ -116,11 +119,13 @@ export const OnboardingWizard: React.FC = () => { }, onError: (error) => { console.error('❌ Failed to auto-complete user_registered step:', error); + // Reset flag on error to allow retry + setAutoCompletionAttempted(false); } }); } } - }, [userProgress, user?.id, markStepCompleted]); + }, [userProgress, user?.id, autoCompletionAttempted, markStepCompleted.isPending]); // Removed markStepCompleted from deps // Initialize step index based on backend progress with validation useEffect(() => { @@ -205,6 +210,12 @@ export const OnboardingWizard: React.FC = () => { return; } + // Prevent concurrent mutations + if (markStepCompleted.isPending) { + console.warn(`⚠️ Step completion already in progress for "${currentStep.id}", skipping duplicate call`); + return; + } + console.log(`🎯 Completing step: "${currentStep.id}" with data:`, data); try { @@ -260,25 +271,50 @@ export const OnboardingWizard: React.FC = () => { } } catch (error: any) { console.error(`❌ Error completing step "${currentStep.id}":`, error); - + // Extract detailed error information const errorMessage = error?.response?.data?.detail || error?.message || 'Unknown error'; const statusCode = error?.response?.status; - + console.error(`πŸ“Š Error details: Status ${statusCode}, Message: ${errorMessage}`); - + + // Handle different types of errors + if (statusCode === 207) { + // Multi-Status: Step updated but summary failed + console.warn(`⚠️ Partial success for step "${currentStep.id}": ${errorMessage}`); + + // Continue with step advancement since the actual step was completed + if (currentStep.id === 'completion') { + // Navigate to dashboard after completion + if (isNewTenant) { + navigate('/app/dashboard'); + } else { + navigate('/app'); + } + } else { + // Auto-advance to next step after successful completion + if (currentStepIndex < STEPS.length - 1) { + setCurrentStepIndex(currentStepIndex + 1); + } + } + + // Show a warning but don't block progress + console.warn(`Step "${currentStep.title}" completed with warnings: ${errorMessage}`); + return; // Don't show error alert + } + // Check if it's a dependency error if (errorMessage.includes('dependencies not met')) { console.error('🚫 Dependencies not met for step:', currentStep.id); - + // Check what dependencies are missing if (userProgress) { console.log('πŸ“‹ Current progress:', userProgress); console.log('πŸ“‹ Completed steps:', userProgress.steps.filter(s => s.completed).map(s => s.step_name)); } } - - // Don't advance automatically on error - user should see the issue + + // Don't advance automatically on real errors - user should see the issue alert(`${t('onboarding:errors.step_failed', 'Error al completar paso')} "${currentStep.title}": ${errorMessage}`); } }; diff --git a/frontend/src/components/domain/onboarding/steps/MLTrainingStep.tsx b/frontend/src/components/domain/onboarding/steps/MLTrainingStep.tsx index 87c7e13c..5a1fc589 100644 --- a/frontend/src/components/domain/onboarding/steps/MLTrainingStep.tsx +++ b/frontend/src/components/domain/onboarding/steps/MLTrainingStep.tsx @@ -1,7 +1,7 @@ import React, { useState, useCallback, useEffect } from 'react'; import { Button } from '../../../ui/Button'; import { useCurrentTenant } from '../../../../stores/tenant.store'; -import { useCreateTrainingJob, useTrainingWebSocket } from '../../../../api/hooks/training'; +import { useCreateTrainingJob, useTrainingWebSocket, useTrainingJobStatus } from '../../../../api/hooks/training'; interface MLTrainingStepProps { onNext: () => void; @@ -85,6 +85,63 @@ export const MLTrainingStep: React.FC = ({ } : undefined ); + // Smart fallback polling - automatically disabled when WebSocket is connected + const { data: jobStatus } = useTrainingJobStatus( + currentTenant?.id || '', + jobId || '', + { + enabled: !!jobId && !!currentTenant?.id, + isWebSocketConnected: isConnected, // This will disable HTTP polling when WebSocket is connected + } + ); + + // Handle training status updates from HTTP polling (fallback only) + useEffect(() => { + if (!jobStatus || !jobId || trainingProgress?.stage === 'completed') { + return; + } + + console.log('πŸ“Š HTTP fallback status update:', jobStatus); + + // Check if training completed via HTTP polling fallback + if (jobStatus.status === 'completed' && trainingProgress?.stage !== 'completed') { + console.log('βœ… Training completion detected via HTTP fallback'); + setTrainingProgress({ + stage: 'completed', + progress: 100, + message: 'Entrenamiento completado exitosamente (detectado por verificaciΓ³n HTTP)' + }); + setIsTraining(false); + + setTimeout(() => { + onComplete({ + jobId: jobId, + success: true, + message: 'Modelo entrenado correctamente', + detectedViaPolling: true + }); + }, 2000); + } else if (jobStatus.status === 'failed') { + console.log('❌ Training failure detected via HTTP fallback'); + setError('Error detectado durante el entrenamiento (verificaciΓ³n de estado)'); + setIsTraining(false); + setTrainingProgress(null); + } else if (jobStatus.status === 'running' && jobStatus.progress !== undefined) { + // Update progress if we have newer information from HTTP polling fallback + const currentProgress = trainingProgress?.progress || 0; + if (jobStatus.progress > currentProgress) { + console.log(`πŸ“ˆ Progress update via HTTP fallback: ${jobStatus.progress}%`); + setTrainingProgress(prev => ({ + ...prev, + stage: 'training', + progress: jobStatus.progress, + message: jobStatus.message || 'Entrenando modelo...', + currentStep: jobStatus.current_step + }) as TrainingProgress); + } + } + }, [jobStatus, jobId, trainingProgress?.stage, onComplete]); + // Auto-trigger training when component mounts useEffect(() => { if (currentTenant?.id && !isTraining && !trainingProgress && !error) { diff --git a/gateway/app/main.py b/gateway/app/main.py index 197b1a6a..abc37cc2 100644 --- a/gateway/app/main.py +++ b/gateway/app/main.py @@ -249,60 +249,193 @@ async def events_stream(request: Request, token: str): @app.websocket("/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live") async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str): - """WebSocket proxy that forwards connections directly to training service""" + """WebSocket proxy that forwards connections directly to training service with enhanced token validation""" await websocket.accept() - + # Get token from query params token = websocket.query_params.get("token") if not token: logger.warning(f"WebSocket connection rejected - missing token for job {job_id}") await websocket.close(code=1008, reason="Authentication token required") return - + + # Validate token using auth middleware + from app.middleware.auth import jwt_handler + try: + payload = jwt_handler.verify_token(token) + if not payload: + logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}") + await websocket.close(code=1008, reason="Invalid authentication token") + return + + # Check token expiration + import time + if payload.get('exp', 0) < time.time(): + logger.warning(f"WebSocket connection rejected - expired token for job {job_id}") + await websocket.close(code=1008, reason="Token expired") + return + + logger.info(f"WebSocket token validated for user {payload.get('email', 'unknown')}") + + except Exception as e: + logger.warning(f"WebSocket token validation failed for job {job_id}: {e}") + await websocket.close(code=1008, reason="Token validation failed") + return + logger.info(f"Proxying WebSocket connection to training service for job {job_id}, tenant {tenant_id}") - + # Build WebSocket URL to training service training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/') training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://') training_ws_url = f"{training_ws_url}/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}" - + + training_ws = None + heartbeat_task = None + try: - # Connect to training service WebSocket + # Connect to training service WebSocket with proper timeout configuration import websockets - async with websockets.connect(training_ws_url) as training_ws: - logger.info(f"Connected to training service WebSocket for job {job_id}") - - async def forward_to_training(): - """Forward messages from frontend to training service""" + + # Configure timeouts to coordinate with frontend (30s heartbeat) and training service + # DISABLE gateway-level ping to avoid dual-ping conflicts - let frontend handle ping/pong + training_ws = await websockets.connect( + training_ws_url, + ping_interval=None, # DISABLED: Let frontend handle ping/pong via message forwarding + ping_timeout=None, # DISABLED: No independent ping mechanism + close_timeout=15, # Reasonable close timeout + max_size=2**20, # 1MB max message size + max_queue=32 # Max queued messages + ) + + logger.info(f"Connected to training service WebSocket for job {job_id} with gateway ping DISABLED (frontend handles ping/pong)") + + # Track connection state properly due to FastAPI WebSocket state propagation bug + connection_alive = True + last_activity = asyncio.get_event_loop().time() + + async def check_connection_health(): + """Monitor connection health based on activity timestamps only - no WebSocket interference""" + nonlocal connection_alive, last_activity + + while connection_alive: try: - async for message in websocket.iter_text(): + await asyncio.sleep(30) # Check every 30 seconds (aligned with frontend heartbeat) + current_time = asyncio.get_event_loop().time() + + # Check if we haven't received any activity for too long + # Frontend sends ping every 30s, so 90s = 3 missed pings before considering dead + if current_time - last_activity > 90: + logger.warning(f"No frontend activity for 90s on job {job_id} - connection may be dead") + # Don't forcibly close - let the forwarding loops handle actual connection issues + # This is just monitoring/logging now + else: + logger.debug(f"Connection health OK for job {job_id} - last activity {int(current_time - last_activity)}s ago") + + except Exception as e: + logger.error(f"Connection health monitoring error for job {job_id}: {e}") + break + + async def forward_to_training(): + """Forward messages from frontend to training service with proper error handling""" + nonlocal connection_alive, last_activity + + try: + while connection_alive and training_ws and training_ws.open: + try: + # Use longer timeout to avoid conflicts with frontend 30s heartbeat + # Frontend sends ping every 30s, so we need to allow for some latency + message = await asyncio.wait_for(websocket.receive_text(), timeout=45.0) + last_activity = asyncio.get_event_loop().time() + + # Forward the message to training service await training_ws.send(message) - except Exception as e: - logger.error(f"Error forwarding to training service: {e}") - - async def forward_to_frontend(): - """Forward messages from training service to frontend""" - try: - async for message in training_ws: + logger.debug(f"Forwarded message to training service for job {job_id}: {message[:100]}...") + + except asyncio.TimeoutError: + # No message received in 45 seconds, continue loop + # This allows for frontend 30s heartbeat + network latency + processing time + continue + except Exception as e: + logger.error(f"Error receiving from frontend for job {job_id}: {e}") + connection_alive = False + break + + except Exception as e: + logger.error(f"Error in forward_to_training for job {job_id}: {e}") + connection_alive = False + + async def forward_to_frontend(): + """Forward messages from training service to frontend with proper error handling""" + nonlocal connection_alive, last_activity + + try: + while connection_alive and training_ws and training_ws.open: + try: + # Use coordinated timeout - training service expects messages every 60s + # This should be longer than training service timeout to avoid premature closure + message = await asyncio.wait_for(training_ws.recv(), timeout=75.0) + last_activity = asyncio.get_event_loop().time() + + # Forward the message to frontend await websocket.send_text(message) - except Exception as e: - logger.error(f"Error forwarding to frontend: {e}") - - # Run both forwarding tasks concurrently + logger.debug(f"Forwarded message to frontend for job {job_id}: {message[:100]}...") + + except asyncio.TimeoutError: + # No message received in 75 seconds, continue loop + # Training service sends heartbeats, so this indicates potential issues + continue + except Exception as e: + logger.error(f"Error receiving from training service for job {job_id}: {e}") + connection_alive = False + break + + except Exception as e: + logger.error(f"Error in forward_to_frontend for job {job_id}: {e}") + connection_alive = False + + # Start connection health monitoring + heartbeat_task = asyncio.create_task(check_connection_health()) + + # Run both forwarding tasks concurrently with proper error handling + try: await asyncio.gather( forward_to_training(), forward_to_frontend(), return_exceptions=True ) - + except Exception as e: + logger.error(f"Error in WebSocket forwarding tasks for job {job_id}: {e}") + finally: + connection_alive = False + + except websockets.exceptions.ConnectionClosedError as e: + logger.warning(f"Training service WebSocket connection closed for job {job_id}: {e}") + except websockets.exceptions.WebSocketException as e: + logger.error(f"WebSocket exception for job {job_id}: {e}") except Exception as e: logger.error(f"WebSocket proxy error for job {job_id}: {e}") - try: - await websocket.close(code=1011, reason="Training service connection failed") - except: - pass finally: - logger.info(f"WebSocket proxy closed for job {job_id}") + # Cleanup + if heartbeat_task and not heartbeat_task.done(): + heartbeat_task.cancel() + try: + await heartbeat_task + except asyncio.CancelledError: + pass + + if training_ws and not training_ws.closed: + try: + await training_ws.close() + except Exception as e: + logger.warning(f"Error closing training service WebSocket for job {job_id}: {e}") + + try: + if not websocket.client_state.name == 'DISCONNECTED': + await websocket.close(code=1000, reason="Proxy connection closed") + except Exception as e: + logger.warning(f"Error closing frontend WebSocket for job {job_id}: {e}") + + logger.info(f"WebSocket proxy cleanup completed for job {job_id}") if __name__ == "__main__": import uvicorn diff --git a/gateway/app/middleware/auth.py b/gateway/app/middleware/auth.py index 4812a619..af1d1f4f 100644 --- a/gateway/app/middleware/auth.py +++ b/gateway/app/middleware/auth.py @@ -67,7 +67,7 @@ class AuthMiddleware(BaseHTTPMiddleware): ) # βœ… STEP 2: Verify token and get user context - user_context = await self._verify_token(token) + user_context = await self._verify_token(token, request) if not user_context: logger.warning(f"Invalid token for route: {request.url.path}") return JSONResponse( @@ -117,7 +117,14 @@ class AuthMiddleware(BaseHTTPMiddleware): tenant_id=tenant_id, path=request.url.path) - return await call_next(request) + # Process the request + response = await call_next(request) + + # Add token expiry warning header if token is near expiry + if hasattr(request.state, 'token_near_expiry') and request.state.token_near_expiry: + response.headers["X-Token-Refresh-Suggested"] = "true" + + return response def _is_public_route(self, path: str) -> bool: """Check if route requires authentication""" @@ -130,7 +137,7 @@ class AuthMiddleware(BaseHTTPMiddleware): return auth_header.split(" ")[1] return None - async def _verify_token(self, token: str) -> Optional[Dict[str, Any]]: + async def _verify_token(self, token: str, request: Request = None) -> Optional[Dict[str, Any]]: """ Verify JWT token with improved fallback strategy FIXED: Better error handling and token structure validation @@ -141,6 +148,17 @@ class AuthMiddleware(BaseHTTPMiddleware): payload = jwt_handler.verify_token(token) if payload and self._validate_token_payload(payload): logger.debug("Token validated locally") + + # Check if token is near expiry and set flag for response header + if request: + import time + exp_time = payload.get("exp", 0) + current_time = time.time() + time_until_expiry = exp_time - current_time + + if time_until_expiry < 300: # 5 minutes + request.state.token_near_expiry = True + # Convert JWT payload to user context format return self._jwt_payload_to_user_context(payload) except Exception as e: @@ -177,18 +195,26 @@ class AuthMiddleware(BaseHTTPMiddleware): """ required_fields = ["user_id", "email", "exp", "type"] missing_fields = [field for field in required_fields if field not in payload] - + if missing_fields: logger.warning(f"Token payload missing fields: {missing_fields}") return False - + # Validate token type token_type = payload.get("type") if token_type not in ["access", "service"]: logger.warning(f"Invalid token type: {payload.get('type')}") return False - + # Check if token is near expiry (within 5 minutes) and log warning + import time + exp_time = payload.get("exp", 0) + current_time = time.time() + time_until_expiry = exp_time - current_time + + if time_until_expiry < 300: # 5 minutes + logger.warning(f"Token expires in {int(time_until_expiry)} seconds for user {payload.get('email')}") + return True def _jwt_payload_to_user_context(self, payload: Dict[str, Any]) -> Dict[str, Any]: diff --git a/infrastructure/kubernetes/base/configmap.yaml b/infrastructure/kubernetes/base/configmap.yaml index b79352dc..00bfdae9 100644 --- a/infrastructure/kubernetes/base/configmap.yaml +++ b/infrastructure/kubernetes/base/configmap.yaml @@ -88,7 +88,7 @@ data: # AUTHENTICATION & SECURITY SETTINGS # ================================================================ JWT_ALGORITHM: "HS256" - JWT_ACCESS_TOKEN_EXPIRE_MINUTES: "30" + JWT_ACCESS_TOKEN_EXPIRE_MINUTES: "240" JWT_REFRESH_TOKEN_EXPIRE_DAYS: "7" ENABLE_SERVICE_AUTH: "false" PASSWORD_MIN_LENGTH: "8" diff --git a/services/auth/app/api/onboarding.py b/services/auth/app/api/onboarding.py index c46cd88f..c985e4a8 100644 --- a/services/auth/app/api/onboarding.py +++ b/services/auth/app/api/onboarding.py @@ -131,15 +131,28 @@ class OnboardingService: # Update the step await self._update_user_onboarding_data( - user_id, - step_name, + user_id, + step_name, { "completed": update_request.completed, "completed_at": datetime.now(timezone.utc).isoformat() if update_request.completed else None, "data": update_request.data or {} } ) - + + # Try to update summary and handle partial failures gracefully + try: + # Update the user's onboarding summary + await self._update_user_summary(user_id) + except HTTPException as he: + # If it's a 207 Multi-Status (partial success), log warning but continue + if he.status_code == status.HTTP_207_MULTI_STATUS: + logger.warning(f"Summary update failed for user {user_id}, step {step_name}: {he.detail}") + # Continue execution - the step update was successful + else: + # Re-raise other HTTP exceptions + raise + # Return updated progress return await self.get_user_progress(user_id) @@ -284,10 +297,7 @@ class OnboardingService: completed=completed, step_data=data_payload ) - - # Update the user's onboarding summary - await self._update_user_summary(user_id) - + logger.info(f"Successfully updated onboarding step for user {user_id}: {step_name} = {step_data}") return updated_step @@ -300,26 +310,26 @@ class OnboardingService: try: # Get updated progress user_progress_data = await self._get_user_onboarding_data(user_id) - + # Calculate current status completed_steps = [] for step_name in ONBOARDING_STEPS: if user_progress_data.get(step_name, {}).get("completed", False): completed_steps.append(step_name) - + # Determine current and next step current_step = self._get_current_step(completed_steps) next_step = self._get_next_step(completed_steps) - + # Calculate completion percentage completion_percentage = (len(completed_steps) / len(ONBOARDING_STEPS)) * 100 - + # Check if fully completed fully_completed = len(completed_steps) == len(ONBOARDING_STEPS) - + # Format steps count steps_completed_count = f"{len(completed_steps)}/{len(ONBOARDING_STEPS)}" - + # Update summary in database await self.onboarding_repo.upsert_user_summary( user_id=user_id, @@ -329,10 +339,18 @@ class OnboardingService: fully_completed=fully_completed, steps_completed_count=steps_completed_count ) - + + logger.debug(f"Successfully updated onboarding summary for user {user_id}") + except Exception as e: - logger.error(f"Error updating onboarding summary for user {user_id}: {e}") - # Don't raise here - summary update failure shouldn't break step updates + logger.error(f"Error updating onboarding summary for user {user_id}: {e}", + extra={"user_id": user_id, "error_type": type(e).__name__}) + # Raise a warning-level HTTPException to inform frontend without breaking the flow + # This allows the step update to succeed while alerting about summary issues + raise HTTPException( + status_code=status.HTTP_207_MULTI_STATUS, + detail=f"Step updated successfully, but summary update failed: {str(e)}" + ) # API Routes diff --git a/services/auth/app/models/onboarding.py b/services/auth/app/models/onboarding.py index 552828bc..bfe5c0cd 100644 --- a/services/auth/app/models/onboarding.py +++ b/services/auth/app/models/onboarding.py @@ -61,11 +61,11 @@ class UserOnboardingSummary(Base): # Summary fields current_step = Column(String(50), nullable=False, default="user_registered") next_step = Column(String(50)) - completion_percentage = Column(String(10), default="0.0") # Store as string for precision + completion_percentage = Column(String(50), default="0.0") # Store as string for precision fully_completed = Column(Boolean, default=False) - + # Progress tracking - steps_completed_count = Column(String(10), default="0") # Store as string: "3/5" + steps_completed_count = Column(String(50), default="0") # Store as string: "3/5" # Timestamps created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) diff --git a/services/training/app/api/websocket.py b/services/training/app/api/websocket.py index cde142df..4446844f 100644 --- a/services/training/app/api/websocket.py +++ b/services/training/app/api/websocket.py @@ -82,11 +82,45 @@ async def training_progress_websocket( tenant_id: str, job_id: str ): - connection_id = f"{tenant_id}_{id(websocket)}" - + # Validate token from query parameters + token = websocket.query_params.get("token") + if not token: + logger.warning(f"WebSocket connection rejected - missing token for job {job_id}") + await websocket.close(code=1008, reason="Authentication token required") + return + + # Validate the token (use the same JWT handler as gateway) + from shared.auth.jwt_handler import JWTHandler + from app.core.config import settings + + jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM) + + try: + payload = jwt_handler.verify_token(token) + if not payload: + logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}") + await websocket.close(code=1008, reason="Invalid authentication token") + return + + # Verify user has access to this tenant + user_id = payload.get('user_id') + if not user_id: + logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}") + await websocket.close(code=1008, reason="Invalid token payload") + return + + logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}") + + except Exception as e: + logger.warning(f"WebSocket token validation failed for job {job_id}: {e}") + await websocket.close(code=1008, reason="Token validation failed") + return + + connection_id = f"{tenant_id}_{user_id}_{id(websocket)}" + await connection_manager.connect(websocket, job_id, connection_id) - logger.info(f"WebSocket connection established for job {job_id}") - + logger.info(f"WebSocket connection established for job {job_id}, user {user_id}") + consumer_task = None training_completed = False @@ -100,11 +134,12 @@ async def training_progress_websocket( while not training_completed: try: - # FIXED: Use receive() instead of receive_text() + # Coordinate with frontend 30s heartbeat + gateway 45s timeout + # This should be longer than gateway timeout to avoid premature closure try: - data = await asyncio.wait_for(websocket.receive(), timeout=30.0) + data = await asyncio.wait_for(websocket.receive(), timeout=60.0) last_activity = asyncio.get_event_loop().time() - + # Handle different message types if data["type"] == "websocket.receive": if "text" in data: @@ -123,31 +158,41 @@ async def training_progress_websocket( elif message_text == "close": logger.info(f"Client requested connection close for job {job_id}") break - + elif "bytes" in data: - # Handle binary messages (WebSocket ping frames) + # Handle binary messages (WebSocket ping frames) - respond with text pong for compatibility await websocket.send_text("pong") - logger.debug(f"Binary ping received for job {job_id}") - + logger.debug(f"Binary ping received for job {job_id}, responding with text pong") + elif data["type"] == "websocket.disconnect": logger.info(f"WebSocket disconnect message received for job {job_id}") break - + except asyncio.TimeoutError: - # No message received in 30 seconds - send heartbeat + # No message received in 60 seconds - this is now coordinated with gateway timeouts current_time = asyncio.get_event_loop().time() - if current_time - last_activity > 60: - logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat") - - try: - await websocket.send_json({ - "type": "heartbeat", - "job_id": job_id, - "timestamp": str(datetime.datetime.now()) - }) - except Exception as e: - logger.error(f"Failed to send heartbeat for job {job_id}: {e}") - break + + # Send heartbeat only if we haven't received frontend ping for too long + # Frontend sends ping every 30s, so 60s timeout + 30s grace = 90s before heartbeat + if current_time - last_activity > 90: # 90 seconds of total inactivity + logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat") + + try: + await websocket.send_json({ + "type": "heartbeat", + "job_id": job_id, + "timestamp": str(datetime.datetime.now()), + "message": "Training service heartbeat - frontend inactive", + "inactivity_seconds": int(current_time - last_activity) + }) + last_activity = current_time + except Exception as e: + logger.error(f"Failed to send heartbeat for job {job_id}: {e}") + break + else: + # Normal timeout, frontend should be sending ping every 30s + logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)") + continue except WebSocketDisconnect: logger.info(f"WebSocket client disconnected for job {job_id}") diff --git a/services/training/app/core/database.py b/services/training/app/core/database.py index 540125dd..23fd93ca 100644 --- a/services/training/app/core/database.py +++ b/services/training/app/core/database.py @@ -43,11 +43,71 @@ async def get_db_health() -> bool: await conn.execute(text("SELECT 1")) logger.debug("Database health check passed") return True - + except Exception as e: logger.error("Database health check failed", error=str(e)) return False +async def get_comprehensive_db_health() -> dict: + """ + Comprehensive health check that verifies both connectivity and table existence + """ + health_status = { + "status": "healthy", + "connectivity": False, + "tables_exist": False, + "tables_verified": [], + "missing_tables": [], + "errors": [] + } + + try: + # Test basic connectivity + health_status["connectivity"] = await get_db_health() + + if not health_status["connectivity"]: + health_status["status"] = "unhealthy" + health_status["errors"].append("Database connectivity failed") + return health_status + + # Test table existence + tables_verified = await _verify_tables_exist() + health_status["tables_exist"] = tables_verified + + if tables_verified: + health_status["tables_verified"] = [ + 'model_training_logs', 'trained_models', 'model_performance_metrics', + 'training_job_queue', 'model_artifacts' + ] + else: + health_status["status"] = "unhealthy" + health_status["errors"].append("Required tables missing or inaccessible") + + # Try to identify which specific tables are missing + try: + async with database_manager.get_session() as session: + for table_name in ['model_training_logs', 'trained_models', 'model_performance_metrics', + 'training_job_queue', 'model_artifacts']: + try: + await session.execute(text(f"SELECT 1 FROM {table_name} LIMIT 1")) + health_status["tables_verified"].append(table_name) + except Exception: + health_status["missing_tables"].append(table_name) + except Exception as e: + health_status["errors"].append(f"Error checking individual tables: {str(e)}") + + logger.debug("Comprehensive database health check completed", + status=health_status["status"], + connectivity=health_status["connectivity"], + tables_exist=health_status["tables_exist"]) + + except Exception as e: + health_status["status"] = "unhealthy" + health_status["errors"].append(f"Health check failed: {str(e)}") + logger.error("Comprehensive database health check failed", error=str(e)) + + return health_status + # Training service specific database utilities class TrainingDatabaseUtils: """Training service specific database utilities""" @@ -223,27 +283,118 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]: # Database initialization for training service async def initialize_training_database(): - """Initialize database tables for training service""" + """Initialize database tables for training service with retry logic and verification""" + import asyncio + from sqlalchemy import text + + max_retries = 5 + retry_delay = 2.0 + + for attempt in range(1, max_retries + 1): + try: + logger.info("Initializing training service database", + attempt=attempt, max_retries=max_retries) + + # Step 1: Test database connectivity first + logger.info("Testing database connectivity...") + connection_ok = await database_manager.test_connection() + if not connection_ok: + raise Exception("Database connection test failed") + logger.info("Database connectivity verified") + + # Step 2: Import models to ensure they're registered + logger.info("Importing and registering database models...") + from app.models.training import ( + ModelTrainingLog, + TrainedModel, + ModelPerformanceMetric, + TrainingJobQueue, + ModelArtifact + ) + + # Verify models are registered in metadata + expected_tables = { + 'model_training_logs', 'trained_models', 'model_performance_metrics', + 'training_job_queue', 'model_artifacts' + } + registered_tables = set(Base.metadata.tables.keys()) + missing_tables = expected_tables - registered_tables + if missing_tables: + raise Exception(f"Models not properly registered: {missing_tables}") + + logger.info("Models registered successfully", + tables=list(registered_tables)) + + # Step 3: Create tables using shared infrastructure with verification + logger.info("Creating database tables...") + await database_manager.create_tables() + + # Step 4: Verify tables were actually created + logger.info("Verifying table creation...") + verification_successful = await _verify_tables_exist() + + if not verification_successful: + raise Exception("Table verification failed - tables were not created properly") + + logger.info("Training service database initialized and verified successfully", + attempt=attempt) + return + + except Exception as e: + logger.error("Database initialization failed", + attempt=attempt, + max_retries=max_retries, + error=str(e)) + + if attempt == max_retries: + logger.error("All database initialization attempts failed - giving up") + raise Exception(f"Failed to initialize training database after {max_retries} attempts: {str(e)}") + + # Wait before retry with exponential backoff + wait_time = retry_delay * (2 ** (attempt - 1)) + logger.info("Retrying database initialization", + retry_in_seconds=wait_time, + next_attempt=attempt + 1) + await asyncio.sleep(wait_time) + +async def _verify_tables_exist() -> bool: + """Verify that all required tables exist in the database""" try: - logger.info("Initializing training service database") - - # Import models to ensure they're registered - from app.models.training import ( - ModelTrainingLog, - TrainedModel, - ModelPerformanceMetric, - TrainingJobQueue, - ModelArtifact - ) - - # Create tables using shared infrastructure - await database_manager.create_tables() - - logger.info("Training service database initialized successfully") - + async with database_manager.get_session() as session: + # Check each required table exists and is accessible + required_tables = [ + 'model_training_logs', + 'trained_models', + 'model_performance_metrics', + 'training_job_queue', + 'model_artifacts' + ] + + for table_name in required_tables: + try: + # Try to query the table structure + result = await session.execute( + text(f"SELECT 1 FROM {table_name} LIMIT 1") + ) + logger.debug(f"Table {table_name} exists and is accessible") + except Exception as table_error: + # If it's a "relation does not exist" error, table creation failed + if "does not exist" in str(table_error).lower(): + logger.error(f"Table {table_name} does not exist", error=str(table_error)) + return False + # If it's an empty table, that's fine - table exists + elif "no data" in str(table_error).lower(): + logger.debug(f"Table {table_name} exists but is empty (normal)") + else: + logger.warning(f"Unexpected error querying {table_name}", error=str(table_error)) + + logger.info("All required tables verified successfully", + tables=required_tables) + return True + except Exception as e: - logger.error("Failed to initialize training service database", error=str(e)) - raise + logger.error("Table verification failed", error=str(e)) + return False # Database cleanup for training service async def cleanup_training_database(): diff --git a/services/training/app/main.py b/services/training/app/main.py index e1bd0c3d..32bab02d 100644 --- a/services/training/app/main.py +++ b/services/training/app/main.py @@ -16,7 +16,7 @@ from fastapi.responses import JSONResponse import uvicorn from app.core.config import settings -from app.core.database import initialize_training_database, cleanup_training_database, get_db_health +from app.core.database import initialize_training_database, cleanup_training_database, get_db_health, get_comprehensive_db_health from app.api import training, models from app.api.websocket import websocket_router @@ -195,18 +195,69 @@ async def health_check(): @app.get("/health/ready") async def readiness_check(): - """Kubernetes readiness probe endpoint""" - checks = { - "database": await get_db_health(), - "application": getattr(app.state, 'ready', False) - } - - if all(checks.values()): - return {"status": "ready", "checks": checks} - else: + """Kubernetes readiness probe endpoint with comprehensive database checks""" + try: + # Get comprehensive database health including table verification + db_health = await get_comprehensive_db_health() + + checks = { + "database_connectivity": db_health["connectivity"], + "database_tables": db_health["tables_exist"], + "application": getattr(app.state, 'ready', False) + } + + # Include detailed database info for debugging + database_details = { + "status": db_health["status"], + "tables_verified": db_health["tables_verified"], + "missing_tables": db_health["missing_tables"], + "errors": db_health["errors"] + } + + # Service is ready only if all checks pass + all_ready = all(checks.values()) and db_health["status"] == "healthy" + + if all_ready: + return { + "status": "ready", + "checks": checks, + "database": database_details + } + else: + return JSONResponse( + status_code=503, + content={ + "status": "not ready", + "checks": checks, + "database": database_details + } + ) + + except Exception as e: + logger.error("Readiness check failed", error=str(e)) return JSONResponse( status_code=503, - content={"status": "not ready", "checks": checks} + content={ + "status": "not ready", + "error": f"Health check failed: {str(e)}" + } + ) + +@app.get("/health/database") +async def database_health_check(): + """Detailed database health endpoint for debugging""" + try: + db_health = await get_comprehensive_db_health() + status_code = 200 if db_health["status"] == "healthy" else 503 + return JSONResponse(status_code=status_code, content=db_health) + except Exception as e: + logger.error("Database health check failed", error=str(e)) + return JSONResponse( + status_code=503, + content={ + "status": "unhealthy", + "error": f"Health check failed: {str(e)}" + } ) @app.get("/metrics") @@ -220,11 +271,6 @@ async def get_metrics(): async def liveness_check(): return {"status": "alive"} -@app.get("/health/ready") -async def readiness_check(): - ready = getattr(app.state, 'ready', True) - return {"status": "ready" if ready else "not ready"} - @app.get("/") async def root(): return {"service": "training-service", "version": "1.0.0"} diff --git a/shared/database/base.py b/shared/database/base.py index 3ee6cec8..dd856343 100644 --- a/shared/database/base.py +++ b/shared/database/base.py @@ -150,12 +150,33 @@ class DatabaseManager: # ===== TABLE MANAGEMENT ===== async def create_tables(self, metadata=None): - """Create database tables""" + """Create database tables with enhanced error handling and transaction verification""" try: target_metadata = metadata or Base.metadata + table_names = list(target_metadata.tables.keys()) + logger.info(f"Creating tables: {table_names}", service=self.service_name) + + # Use explicit transaction with proper error handling async with self.async_engine.begin() as conn: - await conn.run_sync(target_metadata.create_all, checkfirst=True) + try: + # Create tables within the transaction + await conn.run_sync(target_metadata.create_all, checkfirst=True) + + # Verify transaction is not in error state + # Try a simple query to ensure connection is still valid + await conn.execute(text("SELECT 1")) + + logger.info("Database tables creation transaction completed successfully", + service=self.service_name, tables=table_names) + + except Exception as create_error: + logger.error(f"Error during table creation within transaction: {create_error}", + service=self.service_name) + # Re-raise to trigger transaction rollback + raise + logger.info("Database tables created successfully", service=self.service_name) + except Exception as e: # Check if it's a "relation already exists" error which can be safely ignored error_str = str(e).lower() @@ -164,6 +185,14 @@ class DatabaseManager: logger.info("Database tables creation completed (some already existed)", service=self.service_name) else: logger.error(f"Failed to create tables: {e}", service=self.service_name) + + # Check for specific transaction error indicators + if any(indicator in error_str for indicator in [ + "transaction", "rollback", "aborted", "failed sql transaction" + ]): + logger.error("Transaction-related error detected during table creation", + service=self.service_name) + raise DatabaseError(f"Table creation failed: {str(e)}") async def drop_tables(self, metadata=None):