diff --git a/client/src/hooks/AuthContext.tsx b/client/src/hooks/AuthContext.tsx index ca82e10f8f..a0613f113c 100644 --- a/client/src/hooks/AuthContext.tsx +++ b/client/src/hooks/AuthContext.tsx @@ -20,7 +20,7 @@ import { useLogoutUserMutation, useRefreshTokenMutation, } from '~/data-provider'; -import { isSafeRedirect, buildLoginRedirectUrl, getPostLoginRedirect } from '~/utils'; +import { SESSION_KEY, isSafeRedirect, buildLoginRedirectUrl, getPostLoginRedirect } from '~/utils'; import { TAuthConfig, TUserContext, TAuthContext, TResError } from '~/common'; import useTimeout from './useTimeout'; import store from '~/store'; @@ -166,7 +166,14 @@ const AuthContextProvider = ({ } const { user, token = '' } = data ?? {}; if (token) { - setUserContext({ token, isAuthenticated: true, user }); + const storedRedirect = sessionStorage.getItem(SESSION_KEY); + sessionStorage.removeItem(SESSION_KEY); + setUserContext({ + user, + token, + isAuthenticated: true, + redirect: storedRedirect && isSafeRedirect(storedRedirect) ? storedRedirect : '/c/new', + }); return; } console.log('Token is not present. User is not authenticated.'); diff --git a/client/src/hooks/__tests__/AuthContext.spec.tsx b/client/src/hooks/__tests__/AuthContext.spec.tsx index 20af37e3f2..4819f0f6d4 100644 --- a/client/src/hooks/__tests__/AuthContext.spec.tsx +++ b/client/src/hooks/__tests__/AuthContext.spec.tsx @@ -7,6 +7,7 @@ import { MemoryRouter } from 'react-router-dom'; import type { TAuthConfig } from '~/common'; import { AuthContextProvider, useAuthContext } from '../AuthContext'; +import { SESSION_KEY } from '~/utils'; const mockNavigate = jest.fn(); jest.mock('react-router-dom', () => ({ @@ -274,6 +275,143 @@ describe('AuthContextProvider — logout onSuccess/onError handling', () => { expect(window.location.replace).toHaveBeenCalled(); expect(mockRefreshMutate).not.toHaveBeenCalled(); }); +}); + +describe('AuthContextProvider — silentRefresh post-login redirect', () => { + beforeEach(() => { + jest.clearAllMocks(); + sessionStorage.clear(); + }); + + afterEach(() => { + sessionStorage.clear(); + }); + + it('navigates to stored sessionStorage redirect after successful token refresh', () => { + jest.useFakeTimers(); + sessionStorage.setItem(SESSION_KEY, '/c/new?endpoint=bedrock&model=claude-sonnet-4-6'); + + renderProviderLive(); + + expect(mockRefreshMutate).toHaveBeenCalledTimes(1); + const [, refreshOptions] = mockRefreshMutate.mock.calls[0] as [ + unknown, + { onSuccess: (data: unknown) => void }, + ]; + + act(() => { + refreshOptions.onSuccess({ user: { id: '1', role: 'USER' }, token: 'new-token' }); + }); + act(() => { + jest.advanceTimersByTime(100); + }); + + expect(mockNavigate).toHaveBeenCalledWith('/c/new?endpoint=bedrock&model=claude-sonnet-4-6', { + replace: true, + }); + expect(sessionStorage.getItem(SESSION_KEY)).toBeNull(); + jest.useRealTimers(); + }); + + it('navigates to /c/new when no stored redirect exists', () => { + jest.useFakeTimers(); + + renderProviderLive(); + + expect(mockRefreshMutate).toHaveBeenCalledTimes(1); + const [, refreshOptions] = mockRefreshMutate.mock.calls[0] as [ + unknown, + { onSuccess: (data: unknown) => void }, + ]; + + act(() => { + refreshOptions.onSuccess({ user: { id: '1', role: 'USER' }, token: 'new-token' }); + }); + act(() => { + jest.advanceTimersByTime(100); + }); + + expect(mockNavigate).toHaveBeenCalledWith('/c/new', { replace: true }); + jest.useRealTimers(); + }); + + it('does not re-trigger silentRefresh after successful redirect', () => { + jest.useFakeTimers(); + sessionStorage.setItem(SESSION_KEY, '/c/abc?endpoint=bedrock'); + + renderProviderLive(); + + expect(mockRefreshMutate).toHaveBeenCalledTimes(1); + const [, refreshOptions] = mockRefreshMutate.mock.calls[0] as [ + unknown, + { onSuccess: (data: unknown) => void }, + ]; + mockRefreshMutate.mockClear(); + + act(() => { + refreshOptions.onSuccess({ user: { id: '1', role: 'USER' }, token: 'new-token' }); + }); + act(() => { + jest.advanceTimersByTime(100); + }); + + expect(mockNavigate).toHaveBeenCalledTimes(1); + expect(mockNavigate).toHaveBeenCalledWith('/c/abc?endpoint=bedrock', { replace: true }); + expect(mockRefreshMutate).not.toHaveBeenCalled(); + jest.useRealTimers(); + }); + + it('falls back to /c/new for unsafe stored redirect', () => { + jest.useFakeTimers(); + sessionStorage.setItem(SESSION_KEY, 'https://evil.com/steal'); + + renderProviderLive(); + + expect(mockRefreshMutate).toHaveBeenCalledTimes(1); + const [, refreshOptions] = mockRefreshMutate.mock.calls[0] as [ + unknown, + { onSuccess: (data: unknown) => void }, + ]; + + act(() => { + refreshOptions.onSuccess({ user: { id: '1', role: 'USER' }, token: 'new-token' }); + }); + act(() => { + jest.advanceTimersByTime(100); + }); + + expect(mockNavigate).toHaveBeenCalledWith('/c/new', { replace: true }); + expect(mockNavigate).not.toHaveBeenCalledWith('https://evil.com/steal', expect.anything()); + expect(sessionStorage.getItem(SESSION_KEY)).toBeNull(); + jest.useRealTimers(); + }); +}); + +describe('AuthContextProvider — logout error handling', () => { + const originalLocation = window.location; + + beforeEach(() => { + jest.clearAllMocks(); + Object.defineProperty(window, 'location', { + value: { + ...originalLocation, + pathname: '/c/some-chat', + search: '', + hash: '', + replace: jest.fn(), + }, + writable: true, + configurable: true, + }); + }); + + afterEach(() => { + Object.defineProperty(window, 'location', { + value: originalLocation, + writable: true, + configurable: true, + }); + }); it('clears auth state on logout error without external redirect', () => { jest.useFakeTimers(); diff --git a/client/src/routes/__tests__/StartupLayout.spec.tsx b/client/src/routes/__tests__/StartupLayout.spec.tsx index 8d2c183137..3e64d19cf2 100644 --- a/client/src/routes/__tests__/StartupLayout.spec.tsx +++ b/client/src/routes/__tests__/StartupLayout.spec.tsx @@ -2,8 +2,8 @@ import React from 'react'; import { render, waitFor } from '@testing-library/react'; import { createMemoryRouter, RouterProvider } from 'react-router-dom'; +import StartupLayout from '~/routes/Layouts/Startup'; import { SESSION_KEY } from '~/utils'; -import StartupLayout from '../Layouts/Startup'; if (typeof Request === 'undefined') { global.Request = class Request {