mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-03-04 23:30:19 +01:00
Merge branch 'main' into feature/entra-id-azure-integration
This commit is contained in:
commit
631f4b3703
151 changed files with 3677 additions and 1242 deletions
68
packages/api/src/files/context.ts
Normal file
68
packages/api/src/files/context.ts
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { FileSources, mergeFileConfig } from 'librechat-data-provider';
|
||||
import type { fileConfigSchema } from 'librechat-data-provider';
|
||||
import type { IMongoFile } from '@librechat/data-schemas';
|
||||
import type { z } from 'zod';
|
||||
import { processTextWithTokenLimit } from '~/utils/text';
|
||||
|
||||
/**
|
||||
* Extracts text context from attachments and returns formatted text.
|
||||
* This handles text that was already extracted from files (OCR, transcriptions, document text, etc.)
|
||||
* @param params - The parameters object
|
||||
* @param params.attachments - Array of file attachments
|
||||
* @param params.req - Express request object for config access
|
||||
* @param params.tokenCountFn - Function to count tokens in text
|
||||
* @returns The formatted file context text, or undefined if no text found
|
||||
*/
|
||||
export async function extractFileContext({
|
||||
attachments,
|
||||
req,
|
||||
tokenCountFn,
|
||||
}: {
|
||||
attachments: IMongoFile[];
|
||||
req?: {
|
||||
body?: { fileTokenLimit?: number };
|
||||
config?: { fileConfig?: z.infer<typeof fileConfigSchema> };
|
||||
};
|
||||
tokenCountFn: (text: string) => number;
|
||||
}): Promise<string | undefined> {
|
||||
if (!attachments || attachments.length === 0) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const fileConfig = mergeFileConfig(req?.config?.fileConfig);
|
||||
const fileTokenLimit = req?.body?.fileTokenLimit ?? fileConfig.fileTokenLimit;
|
||||
|
||||
if (!fileTokenLimit) {
|
||||
// If no token limit, return undefined (no processing)
|
||||
return undefined;
|
||||
}
|
||||
|
||||
let resultText = '';
|
||||
|
||||
for (const file of attachments) {
|
||||
const source = file.source ?? FileSources.local;
|
||||
if (source === FileSources.text && file.text) {
|
||||
const { text: limitedText, wasTruncated } = await processTextWithTokenLimit({
|
||||
text: file.text,
|
||||
tokenLimit: fileTokenLimit,
|
||||
tokenCountFn,
|
||||
});
|
||||
|
||||
if (wasTruncated) {
|
||||
logger.debug(
|
||||
`[extractFileContext] Text content truncated for file: ${file.filename} due to token limits`,
|
||||
);
|
||||
}
|
||||
|
||||
resultText += `${!resultText ? 'Attached document(s):\n```md' : '\n\n---\n\n'}# "${file.filename}"\n${limitedText}\n`;
|
||||
}
|
||||
}
|
||||
|
||||
if (resultText) {
|
||||
resultText += '\n```';
|
||||
return resultText;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
export * from './audio';
|
||||
export * from './context';
|
||||
export * from './encode';
|
||||
export * from './mistral/crud';
|
||||
export * from './ocr';
|
||||
|
|
|
|||
|
|
@ -142,6 +142,7 @@ export class MCPConnectionFactory {
|
|||
serverName: metadata.serverName,
|
||||
clientInfo: metadata.clientInfo,
|
||||
},
|
||||
this.serverConfig.oauth_headers ?? {},
|
||||
this.serverConfig.oauth,
|
||||
);
|
||||
};
|
||||
|
|
@ -161,6 +162,7 @@ export class MCPConnectionFactory {
|
|||
this.serverName,
|
||||
data.serverUrl || '',
|
||||
this.userId!,
|
||||
config?.oauth_headers ?? {},
|
||||
config?.oauth,
|
||||
);
|
||||
|
||||
|
|
@ -358,6 +360,7 @@ export class MCPConnectionFactory {
|
|||
this.serverName,
|
||||
serverUrl,
|
||||
this.userId!,
|
||||
this.serverConfig.oauth_headers ?? {},
|
||||
this.serverConfig.oauth,
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -255,6 +255,7 @@ describe('MCPConnectionFactory', () => {
|
|||
'test-server',
|
||||
'https://api.example.com',
|
||||
'user123',
|
||||
{},
|
||||
undefined,
|
||||
);
|
||||
expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com');
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import type { MCPOptions } from 'librechat-data-provider';
|
||||
import type { AuthorizationServerMetadata } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import { MCPOAuthHandler } from '~/mcp/oauth';
|
||||
import { MCPOAuthFlowMetadata, MCPOAuthHandler, MCPOAuthTokens } from '~/mcp/oauth';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
|
|
@ -14,18 +14,33 @@ jest.mock('@librechat/data-schemas', () => ({
|
|||
jest.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({
|
||||
startAuthorization: jest.fn(),
|
||||
discoverAuthorizationServerMetadata: jest.fn(),
|
||||
discoverOAuthProtectedResourceMetadata: jest.fn(),
|
||||
registerClient: jest.fn(),
|
||||
exchangeAuthorization: jest.fn(),
|
||||
}));
|
||||
|
||||
import {
|
||||
startAuthorization,
|
||||
discoverAuthorizationServerMetadata,
|
||||
discoverOAuthProtectedResourceMetadata,
|
||||
registerClient,
|
||||
exchangeAuthorization,
|
||||
} from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
import { FlowStateManager } from '../../flow/manager';
|
||||
|
||||
const mockStartAuthorization = startAuthorization as jest.MockedFunction<typeof startAuthorization>;
|
||||
const mockDiscoverAuthorizationServerMetadata =
|
||||
discoverAuthorizationServerMetadata as jest.MockedFunction<
|
||||
typeof discoverAuthorizationServerMetadata
|
||||
>;
|
||||
const mockDiscoverOAuthProtectedResourceMetadata =
|
||||
discoverOAuthProtectedResourceMetadata as jest.MockedFunction<
|
||||
typeof discoverOAuthProtectedResourceMetadata
|
||||
>;
|
||||
const mockRegisterClient = registerClient as jest.MockedFunction<typeof registerClient>;
|
||||
const mockExchangeAuthorization = exchangeAuthorization as jest.MockedFunction<
|
||||
typeof exchangeAuthorization
|
||||
>;
|
||||
|
||||
describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||
const mockServerName = 'test-server';
|
||||
|
|
@ -60,6 +75,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
mockServerName,
|
||||
mockServerUrl,
|
||||
mockUserId,
|
||||
{},
|
||||
baseConfig,
|
||||
);
|
||||
|
||||
|
|
@ -82,7 +98,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
grant_types_supported: ['authorization_code'],
|
||||
};
|
||||
|
||||
await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config);
|
||||
await MCPOAuthHandler.initiateOAuthFlow(
|
||||
mockServerName,
|
||||
mockServerUrl,
|
||||
mockUserId,
|
||||
{},
|
||||
config,
|
||||
);
|
||||
|
||||
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
||||
mockServerUrl,
|
||||
|
|
@ -100,7 +122,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
token_endpoint_auth_methods_supported: ['client_secret_post'],
|
||||
};
|
||||
|
||||
await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config);
|
||||
await MCPOAuthHandler.initiateOAuthFlow(
|
||||
mockServerName,
|
||||
mockServerUrl,
|
||||
mockUserId,
|
||||
{},
|
||||
config,
|
||||
);
|
||||
|
||||
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
||||
mockServerUrl,
|
||||
|
|
@ -118,7 +146,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
response_types_supported: ['code', 'token'],
|
||||
};
|
||||
|
||||
await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config);
|
||||
await MCPOAuthHandler.initiateOAuthFlow(
|
||||
mockServerName,
|
||||
mockServerUrl,
|
||||
mockUserId,
|
||||
{},
|
||||
config,
|
||||
);
|
||||
|
||||
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
||||
mockServerUrl,
|
||||
|
|
@ -136,7 +170,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
code_challenge_methods_supported: ['S256'],
|
||||
};
|
||||
|
||||
await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config);
|
||||
await MCPOAuthHandler.initiateOAuthFlow(
|
||||
mockServerName,
|
||||
mockServerUrl,
|
||||
mockUserId,
|
||||
{},
|
||||
config,
|
||||
);
|
||||
|
||||
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
||||
mockServerUrl,
|
||||
|
|
@ -157,7 +197,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
code_challenge_methods_supported: ['S256'],
|
||||
};
|
||||
|
||||
await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config);
|
||||
await MCPOAuthHandler.initiateOAuthFlow(
|
||||
mockServerName,
|
||||
mockServerUrl,
|
||||
mockUserId,
|
||||
{},
|
||||
config,
|
||||
);
|
||||
|
||||
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
||||
mockServerUrl,
|
||||
|
|
@ -181,7 +227,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
code_challenge_methods_supported: [],
|
||||
};
|
||||
|
||||
await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config);
|
||||
await MCPOAuthHandler.initiateOAuthFlow(
|
||||
mockServerName,
|
||||
mockServerUrl,
|
||||
mockUserId,
|
||||
{},
|
||||
config,
|
||||
);
|
||||
|
||||
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
||||
mockServerUrl,
|
||||
|
|
@ -251,7 +303,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
}),
|
||||
} as Response);
|
||||
|
||||
const result = await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata);
|
||||
const result = await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {});
|
||||
|
||||
// Verify the call was made without Authorization header
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
|
|
@ -314,7 +366,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
}),
|
||||
} as Response);
|
||||
|
||||
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata);
|
||||
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {});
|
||||
|
||||
const expectedAuth = `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`;
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
|
|
@ -363,7 +415,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
}),
|
||||
} as Response);
|
||||
|
||||
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata);
|
||||
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {});
|
||||
|
||||
const expectedAuth = `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`;
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
|
|
@ -410,7 +462,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
}),
|
||||
} as Response);
|
||||
|
||||
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata);
|
||||
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {});
|
||||
|
||||
const expectedAuth = `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`;
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
|
|
@ -457,7 +509,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
}),
|
||||
} as Response);
|
||||
|
||||
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata);
|
||||
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {});
|
||||
|
||||
// Verify the call was made without Authorization header
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
|
|
@ -498,6 +550,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
await MCPOAuthHandler.refreshOAuthTokens(
|
||||
mockRefreshToken,
|
||||
{ serverName: 'test-server' },
|
||||
{},
|
||||
config,
|
||||
);
|
||||
|
||||
|
|
@ -539,6 +592,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
await MCPOAuthHandler.refreshOAuthTokens(
|
||||
mockRefreshToken,
|
||||
{ serverName: 'test-server' },
|
||||
{},
|
||||
config,
|
||||
);
|
||||
|
||||
|
|
@ -575,6 +629,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
await MCPOAuthHandler.refreshOAuthTokens(
|
||||
mockRefreshToken,
|
||||
{ serverName: 'test-server' },
|
||||
{},
|
||||
config,
|
||||
);
|
||||
|
||||
|
|
@ -617,7 +672,9 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
'{"error":"invalid_request","error_description":"refresh_token.client_id: Field required"}',
|
||||
} as Response);
|
||||
|
||||
await expect(MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata)).rejects.toThrow(
|
||||
await expect(
|
||||
MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {}),
|
||||
).rejects.toThrow(
|
||||
'Token refresh failed: 400 Bad Request - {"error":"invalid_request","error_description":"refresh_token.client_id: Field required"}',
|
||||
);
|
||||
});
|
||||
|
|
@ -813,4 +870,126 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Custom OAuth Headers', () => {
|
||||
const originalFetch = global.fetch;
|
||||
const mockFetch = jest.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
global.fetch = mockFetch as unknown as typeof fetch;
|
||||
mockFetch.mockResolvedValue({ ok: true, json: async () => ({}) } as Response);
|
||||
mockDiscoverAuthorizationServerMetadata.mockResolvedValue({
|
||||
issuer: 'http://example.com',
|
||||
authorization_endpoint: 'http://example.com/auth',
|
||||
token_endpoint: 'http://example.com/token',
|
||||
response_types_supported: ['code'],
|
||||
} as AuthorizationServerMetadata);
|
||||
mockStartAuthorization.mockResolvedValue({
|
||||
authorizationUrl: new URL('http://example.com/auth'),
|
||||
codeVerifier: 'test-verifier',
|
||||
});
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
global.fetch = originalFetch;
|
||||
});
|
||||
|
||||
it('passes headers to client registration', async () => {
|
||||
mockRegisterClient.mockImplementation(async (_, options) => {
|
||||
await options.fetchFn?.('http://example.com/register', {});
|
||||
return { client_id: 'test', redirect_uris: [] };
|
||||
});
|
||||
|
||||
await MCPOAuthHandler.initiateOAuthFlow(
|
||||
'test-server',
|
||||
'http://example.com',
|
||||
'user-123',
|
||||
{ foo: 'bar' },
|
||||
{},
|
||||
);
|
||||
|
||||
const headers = mockFetch.mock.calls[0][1]?.headers as Headers;
|
||||
expect(headers.get('foo')).toBe('bar');
|
||||
});
|
||||
|
||||
it('passes headers to discovery operations', async () => {
|
||||
mockDiscoverOAuthProtectedResourceMetadata.mockImplementation(async (_, __, fetchFn) => {
|
||||
await fetchFn?.('http://example.com/.well-known/oauth-protected-resource', {});
|
||||
return {
|
||||
resource: 'http://example.com',
|
||||
authorization_servers: ['http://auth.example.com'],
|
||||
};
|
||||
});
|
||||
|
||||
await MCPOAuthHandler.initiateOAuthFlow(
|
||||
'test-server',
|
||||
'http://example.com',
|
||||
'user-123',
|
||||
{ foo: 'bar' },
|
||||
{},
|
||||
);
|
||||
|
||||
const allHaveHeader = mockFetch.mock.calls.every((call) => {
|
||||
const headers = call[1]?.headers as Headers;
|
||||
return headers?.get('foo') === 'bar';
|
||||
});
|
||||
expect(allHaveHeader).toBe(true);
|
||||
});
|
||||
|
||||
it('passes headers to token exchange', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({
|
||||
status: 'PENDING',
|
||||
metadata: {
|
||||
serverName: 'test-server',
|
||||
codeVerifier: 'test-verifier',
|
||||
clientInfo: {},
|
||||
metadata: {},
|
||||
} as MCPOAuthFlowMetadata,
|
||||
}),
|
||||
completeFlow: jest.fn(),
|
||||
} as unknown as FlowStateManager<MCPOAuthTokens>;
|
||||
|
||||
mockExchangeAuthorization.mockImplementation(async (_, options) => {
|
||||
await options.fetchFn?.('http://example.com/token', {});
|
||||
return { access_token: 'test-token', token_type: 'Bearer', expires_in: 3600 };
|
||||
});
|
||||
|
||||
await MCPOAuthHandler.completeOAuthFlow('test-flow-id', 'test-auth-code', mockFlowManager, {
|
||||
foo: 'bar',
|
||||
});
|
||||
|
||||
const headers = mockFetch.mock.calls[0][1]?.headers as Headers;
|
||||
expect(headers.get('foo')).toBe('bar');
|
||||
});
|
||||
|
||||
it('passes headers to token refresh', async () => {
|
||||
mockDiscoverAuthorizationServerMetadata.mockImplementation(async (_, options) => {
|
||||
await options?.fetchFn?.('http://example.com/.well-known/oauth-authorization-server', {});
|
||||
return {
|
||||
issuer: 'http://example.com',
|
||||
token_endpoint: 'http://example.com/token',
|
||||
} as AuthorizationServerMetadata;
|
||||
});
|
||||
|
||||
await MCPOAuthHandler.refreshOAuthTokens(
|
||||
'test-refresh-token',
|
||||
{
|
||||
serverName: 'test-server',
|
||||
serverUrl: 'http://example.com',
|
||||
clientInfo: { client_id: 'test-client', client_secret: 'test-secret' },
|
||||
},
|
||||
{ foo: 'bar' },
|
||||
{},
|
||||
);
|
||||
|
||||
const discoveryCall = mockFetch.mock.calls.find((call) =>
|
||||
call[0].toString().includes('.well-known'),
|
||||
);
|
||||
expect(discoveryCall).toBeDefined();
|
||||
const headers = discoveryCall![1]?.headers as Headers;
|
||||
expect(headers.get('foo')).toBe('bar');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import type {
|
|||
OAuthMetadata,
|
||||
} from './types';
|
||||
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
||||
import { FetchLike } from '@modelcontextprotocol/sdk/shared/transport';
|
||||
|
||||
/** Type for the OAuth metadata from the SDK */
|
||||
type SDKOAuthMetadata = Parameters<typeof registerClient>[1]['metadata'];
|
||||
|
|
@ -26,10 +27,29 @@ export class MCPOAuthHandler {
|
|||
private static readonly FLOW_TYPE = 'mcp_oauth';
|
||||
private static readonly FLOW_TTL = 10 * 60 * 1000; // 10 minutes
|
||||
|
||||
/**
|
||||
* Creates a fetch function with custom headers injected
|
||||
*/
|
||||
private static createOAuthFetch(headers: Record<string, string>): FetchLike {
|
||||
return async (url: string | URL, init?: RequestInit): Promise<Response> => {
|
||||
const newHeaders = new Headers(init?.headers ?? {});
|
||||
for (const [key, value] of Object.entries(headers)) {
|
||||
newHeaders.set(key, value);
|
||||
}
|
||||
return fetch(url, {
|
||||
...init,
|
||||
headers: newHeaders,
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers OAuth metadata from the server
|
||||
*/
|
||||
private static async discoverMetadata(serverUrl: string): Promise<{
|
||||
private static async discoverMetadata(
|
||||
serverUrl: string,
|
||||
oauthHeaders: Record<string, string>,
|
||||
): Promise<{
|
||||
metadata: OAuthMetadata;
|
||||
resourceMetadata?: OAuthProtectedResourceMetadata;
|
||||
authServerUrl: URL;
|
||||
|
|
@ -41,12 +61,14 @@ export class MCPOAuthHandler {
|
|||
let authServerUrl = new URL(serverUrl);
|
||||
let resourceMetadata: OAuthProtectedResourceMetadata | undefined;
|
||||
|
||||
const fetchFn = this.createOAuthFetch(oauthHeaders);
|
||||
|
||||
try {
|
||||
// Try to discover resource metadata first
|
||||
logger.debug(
|
||||
`[MCPOAuth] Attempting to discover protected resource metadata from ${serverUrl}`,
|
||||
);
|
||||
resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl);
|
||||
resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, {}, fetchFn);
|
||||
|
||||
if (resourceMetadata?.authorization_servers?.length) {
|
||||
authServerUrl = new URL(resourceMetadata.authorization_servers[0]);
|
||||
|
|
@ -66,7 +88,9 @@ export class MCPOAuthHandler {
|
|||
logger.debug(
|
||||
`[MCPOAuth] Discovering OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
|
||||
);
|
||||
const rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl);
|
||||
const rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl, {
|
||||
fetchFn,
|
||||
});
|
||||
|
||||
if (!rawMetadata) {
|
||||
logger.error(
|
||||
|
|
@ -92,6 +116,7 @@ export class MCPOAuthHandler {
|
|||
private static async registerOAuthClient(
|
||||
serverUrl: string,
|
||||
metadata: OAuthMetadata,
|
||||
oauthHeaders: Record<string, string>,
|
||||
resourceMetadata?: OAuthProtectedResourceMetadata,
|
||||
redirectUri?: string,
|
||||
): Promise<OAuthClientInformation> {
|
||||
|
|
@ -159,6 +184,7 @@ export class MCPOAuthHandler {
|
|||
const clientInfo = await registerClient(serverUrl, {
|
||||
metadata: metadata as unknown as SDKOAuthMetadata,
|
||||
clientMetadata,
|
||||
fetchFn: this.createOAuthFetch(oauthHeaders),
|
||||
});
|
||||
|
||||
logger.debug(
|
||||
|
|
@ -181,7 +207,8 @@ export class MCPOAuthHandler {
|
|||
serverName: string,
|
||||
serverUrl: string,
|
||||
userId: string,
|
||||
config: MCPOptions['oauth'] | undefined,
|
||||
oauthHeaders: Record<string, string>,
|
||||
config?: MCPOptions['oauth'],
|
||||
): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> {
|
||||
logger.debug(
|
||||
`[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${sanitizeUrlForLogging(serverUrl)}`,
|
||||
|
|
@ -259,7 +286,10 @@ export class MCPOAuthHandler {
|
|||
logger.debug(
|
||||
`[MCPOAuth] Starting auto-discovery of OAuth metadata from ${sanitizeUrlForLogging(serverUrl)}`,
|
||||
);
|
||||
const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata(serverUrl);
|
||||
const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata(
|
||||
serverUrl,
|
||||
oauthHeaders,
|
||||
);
|
||||
|
||||
logger.debug(
|
||||
`[MCPOAuth] OAuth metadata discovered, auth server URL: ${sanitizeUrlForLogging(authServerUrl)}`,
|
||||
|
|
@ -272,6 +302,7 @@ export class MCPOAuthHandler {
|
|||
const clientInfo = await this.registerOAuthClient(
|
||||
authServerUrl.toString(),
|
||||
metadata,
|
||||
oauthHeaders,
|
||||
resourceMetadata,
|
||||
redirectUri,
|
||||
);
|
||||
|
|
@ -365,6 +396,7 @@ export class MCPOAuthHandler {
|
|||
flowId: string,
|
||||
authorizationCode: string,
|
||||
flowManager: FlowStateManager<MCPOAuthTokens>,
|
||||
oauthHeaders: Record<string, string>,
|
||||
): Promise<MCPOAuthTokens> {
|
||||
try {
|
||||
/** Flow state which contains our metadata */
|
||||
|
|
@ -404,6 +436,7 @@ export class MCPOAuthHandler {
|
|||
codeVerifier: metadata.codeVerifier,
|
||||
authorizationCode,
|
||||
resource,
|
||||
fetchFn: this.createOAuthFetch(oauthHeaders),
|
||||
});
|
||||
|
||||
logger.debug('[MCPOAuth] Raw tokens from exchange:', {
|
||||
|
|
@ -476,6 +509,7 @@ export class MCPOAuthHandler {
|
|||
static async refreshOAuthTokens(
|
||||
refreshToken: string,
|
||||
metadata: { serverName: string; serverUrl?: string; clientInfo?: OAuthClientInformation },
|
||||
oauthHeaders: Record<string, string>,
|
||||
config?: MCPOptions['oauth'],
|
||||
): Promise<MCPOAuthTokens> {
|
||||
logger.debug(`[MCPOAuth] Refreshing tokens for ${metadata.serverName}`);
|
||||
|
|
@ -509,7 +543,9 @@ export class MCPOAuthHandler {
|
|||
throw new Error('No token URL available for refresh');
|
||||
} else {
|
||||
/** Auto-discover OAuth configuration for refresh */
|
||||
const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl);
|
||||
const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, {
|
||||
fetchFn: this.createOAuthFetch(oauthHeaders),
|
||||
});
|
||||
if (!oauthMetadata) {
|
||||
throw new Error('Failed to discover OAuth metadata for token refresh');
|
||||
}
|
||||
|
|
@ -533,6 +569,7 @@ export class MCPOAuthHandler {
|
|||
const headers: HeadersInit = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
...oauthHeaders,
|
||||
};
|
||||
|
||||
/** Handle authentication based on server's advertised methods */
|
||||
|
|
@ -613,6 +650,7 @@ export class MCPOAuthHandler {
|
|||
const headers: HeadersInit = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
...oauthHeaders,
|
||||
};
|
||||
|
||||
/** Handle authentication based on configured methods */
|
||||
|
|
@ -684,7 +722,9 @@ export class MCPOAuthHandler {
|
|||
}
|
||||
|
||||
/** Auto-discover OAuth configuration for refresh */
|
||||
const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl);
|
||||
const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, {
|
||||
fetchFn: this.createOAuthFetch(oauthHeaders),
|
||||
});
|
||||
|
||||
if (!oauthMetadata?.token_endpoint) {
|
||||
throw new Error('No token endpoint found in OAuth metadata');
|
||||
|
|
@ -700,6 +740,7 @@ export class MCPOAuthHandler {
|
|||
const headers: HeadersInit = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Accept: 'application/json',
|
||||
...oauthHeaders,
|
||||
};
|
||||
|
||||
const response = await fetch(tokenUrl, {
|
||||
|
|
@ -742,6 +783,7 @@ export class MCPOAuthHandler {
|
|||
revocationEndpoint?: string;
|
||||
revocationEndpointAuthMethodsSupported?: string[];
|
||||
},
|
||||
oauthHeaders: Record<string, string> = {},
|
||||
): Promise<void> {
|
||||
// build the revoke URL, falling back to the server URL + /revoke if no revocation endpoint is provided
|
||||
const revokeUrl: URL =
|
||||
|
|
@ -759,6 +801,7 @@ export class MCPOAuthHandler {
|
|||
// init the request headers
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
...oauthHeaders,
|
||||
};
|
||||
|
||||
// init the request body
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ interface DropdownProps {
|
|||
iconOnly?: boolean;
|
||||
renderValue?: (option: Option) => React.ReactNode;
|
||||
ariaLabel?: string;
|
||||
'aria-labelledby'?: string;
|
||||
portal?: boolean;
|
||||
}
|
||||
|
||||
|
|
@ -37,6 +38,7 @@ const Dropdown: React.FC<DropdownProps> = ({
|
|||
iconOnly = false,
|
||||
renderValue,
|
||||
ariaLabel,
|
||||
'aria-labelledby': ariaLabelledBy,
|
||||
portal = true,
|
||||
}) => {
|
||||
const handleChange = (value: string) => {
|
||||
|
|
@ -77,6 +79,7 @@ const Dropdown: React.FC<DropdownProps> = ({
|
|||
)}
|
||||
data-testid={testId}
|
||||
aria-label={ariaLabel}
|
||||
aria-labelledby={ariaLabelledBy}
|
||||
>
|
||||
<div className="flex w-full items-center gap-2">
|
||||
{icon}
|
||||
|
|
|
|||
|
|
@ -1,191 +1,225 @@
|
|||
import * as React from 'react';
|
||||
import * as DropdownMenuPrimitive from '@radix-ui/react-dropdown-menu';
|
||||
import { Check, ChevronRight, Circle } from 'lucide-react';
|
||||
import { CheckIcon, ChevronRightIcon, CircleIcon } from 'lucide-react';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
const DropdownMenu = DropdownMenuPrimitive.Root;
|
||||
function DropdownMenu({ ...props }: React.ComponentProps<typeof DropdownMenuPrimitive.Root>) {
|
||||
return <DropdownMenuPrimitive.Root data-slot="dropdown-menu" {...props} />;
|
||||
}
|
||||
|
||||
const DropdownMenuTrigger = DropdownMenuPrimitive.Trigger;
|
||||
function DropdownMenuPortal({
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuPrimitive.Portal>) {
|
||||
return <DropdownMenuPrimitive.Portal data-slot="dropdown-menu-portal" {...props} />;
|
||||
}
|
||||
|
||||
const DropdownMenuGroup = DropdownMenuPrimitive.Group;
|
||||
function DropdownMenuTrigger({
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuPrimitive.Trigger>) {
|
||||
return <DropdownMenuPrimitive.Trigger data-slot="dropdown-menu-trigger" {...props} />;
|
||||
}
|
||||
|
||||
const DropdownMenuPortal = DropdownMenuPrimitive.Portal;
|
||||
function DropdownMenuContent({
|
||||
className,
|
||||
sideOffset = 4,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuPrimitive.Content>) {
|
||||
return (
|
||||
<DropdownMenuPrimitive.Portal>
|
||||
<DropdownMenuPrimitive.Content
|
||||
data-slot="dropdown-menu-content"
|
||||
sideOffset={sideOffset}
|
||||
className={cn(
|
||||
'text-popover-foreground max-h-(--radix-dropdown-menu-content-available-height) origin-(--radix-dropdown-menu-content-transform-origin) z-50 min-w-[8rem] overflow-y-auto overflow-x-hidden rounded-md border border-border-light bg-surface-secondary p-1 shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
</DropdownMenuPrimitive.Portal>
|
||||
);
|
||||
}
|
||||
|
||||
const DropdownMenuSub = DropdownMenuPrimitive.Sub;
|
||||
function DropdownMenuGroup({ ...props }: React.ComponentProps<typeof DropdownMenuPrimitive.Group>) {
|
||||
return <DropdownMenuPrimitive.Group data-slot="dropdown-menu-group" {...props} />;
|
||||
}
|
||||
|
||||
const DropdownMenuRadioGroup = DropdownMenuPrimitive.RadioGroup;
|
||||
|
||||
const DropdownMenuSubTrigger = React.forwardRef<
|
||||
React.ElementRef<typeof DropdownMenuPrimitive.SubTrigger>,
|
||||
React.ComponentPropsWithoutRef<typeof DropdownMenuPrimitive.SubTrigger> & {
|
||||
inset?: boolean;
|
||||
}
|
||||
>(({ className = '', inset, children, ...props }, ref) => (
|
||||
<DropdownMenuPrimitive.SubTrigger
|
||||
ref={ref}
|
||||
className={cn(
|
||||
'flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm font-medium outline-none focus:bg-gray-100 data-[state=open]:bg-gray-100 dark:focus:bg-gray-900 dark:data-[state=open]:bg-gray-900',
|
||||
inset ? 'pl-8' : '',
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
<ChevronRight className="ml-auto h-4 w-4" />
|
||||
</DropdownMenuPrimitive.SubTrigger>
|
||||
));
|
||||
DropdownMenuSubTrigger.displayName = DropdownMenuPrimitive.SubTrigger.displayName;
|
||||
|
||||
const DropdownMenuSubContent = React.forwardRef<
|
||||
React.ElementRef<typeof DropdownMenuPrimitive.SubContent>,
|
||||
React.ComponentPropsWithoutRef<typeof DropdownMenuPrimitive.SubContent>
|
||||
>(({ className = '', ...props }, ref) => (
|
||||
<DropdownMenuPrimitive.SubContent
|
||||
ref={ref}
|
||||
className={cn(
|
||||
'z-50 min-w-[8rem] overflow-hidden rounded-md border border-gray-100 bg-white p-1 text-gray-700 shadow-md animate-in slide-in-from-left-1 dark:border-gray-800 dark:bg-gray-800 dark:text-gray-400',
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
DropdownMenuSubContent.displayName = DropdownMenuPrimitive.SubContent.displayName;
|
||||
|
||||
const DropdownMenuContent = React.forwardRef<
|
||||
React.ElementRef<typeof DropdownMenuPrimitive.Content>,
|
||||
React.ComponentPropsWithoutRef<typeof DropdownMenuPrimitive.Content>
|
||||
>(({ className = '', sideOffset = 4, ...props }, ref) => (
|
||||
<DropdownMenuPrimitive.Portal>
|
||||
<DropdownMenuPrimitive.Content
|
||||
ref={ref}
|
||||
sideOffset={sideOffset}
|
||||
function DropdownMenuItem({
|
||||
className,
|
||||
inset,
|
||||
variant = 'default',
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuPrimitive.Item> & {
|
||||
inset?: boolean;
|
||||
variant?: 'default' | 'destructive';
|
||||
}) {
|
||||
return (
|
||||
<DropdownMenuPrimitive.Item
|
||||
data-slot="dropdown-menu-item"
|
||||
data-inset={inset}
|
||||
data-variant={variant}
|
||||
className={cn(
|
||||
'z-50 min-w-[8rem] overflow-hidden rounded-md border border-gray-100 bg-white p-1 text-gray-700 shadow-md animate-in data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 dark:border-gray-800 dark:bg-gray-800 dark:text-gray-400',
|
||||
"data-[variant=destructive]:*:[svg]:!text-destructive outline-hidden relative flex cursor-default select-none items-center gap-2 rounded-sm px-2 py-1.5 text-sm focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[inset]:pl-8 data-[variant=destructive]:text-destructive data-[disabled]:opacity-50 data-[variant=destructive]:focus:bg-destructive/10 data-[variant=destructive]:focus:text-destructive dark:data-[variant=destructive]:focus:bg-destructive/20 [&_svg:not([class*='size-'])]:size-4 [&_svg:not([class*='text-'])]:text-muted-foreground [&_svg]:pointer-events-none [&_svg]:shrink-0",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
</DropdownMenuPrimitive.Portal>
|
||||
));
|
||||
DropdownMenuContent.displayName = DropdownMenuPrimitive.Content.displayName;
|
||||
|
||||
const DropdownMenuItem = React.forwardRef<
|
||||
React.ElementRef<typeof DropdownMenuPrimitive.Item>,
|
||||
React.ComponentPropsWithoutRef<typeof DropdownMenuPrimitive.Item> & {
|
||||
inset?: boolean;
|
||||
}
|
||||
>(({ className = '', inset, ...props }, ref) => (
|
||||
<DropdownMenuPrimitive.Item
|
||||
ref={ref}
|
||||
className={cn(
|
||||
'relative flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm font-medium outline-none focus:bg-gray-100 data-[disabled]:pointer-events-none data-[disabled]:opacity-50 dark:focus:bg-gray-900',
|
||||
inset ? 'pl-8' : '',
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
DropdownMenuItem.displayName = DropdownMenuPrimitive.Item.displayName;
|
||||
|
||||
const DropdownMenuCheckboxItem = React.forwardRef<
|
||||
React.ElementRef<typeof DropdownMenuPrimitive.CheckboxItem>,
|
||||
React.ComponentPropsWithoutRef<typeof DropdownMenuPrimitive.CheckboxItem>
|
||||
>(({ className = '', children, checked, ...props }, ref) => (
|
||||
<DropdownMenuPrimitive.CheckboxItem
|
||||
ref={ref}
|
||||
className={cn(
|
||||
'relative flex cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm font-medium outline-none focus:bg-gray-100 data-[disabled]:pointer-events-none data-[disabled]:opacity-50 dark:focus:bg-gray-900',
|
||||
className,
|
||||
)}
|
||||
checked={checked}
|
||||
{...props}
|
||||
>
|
||||
<span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
|
||||
<DropdownMenuPrimitive.ItemIndicator>
|
||||
<Check className="h-4 w-4" />
|
||||
</DropdownMenuPrimitive.ItemIndicator>
|
||||
</span>
|
||||
{children}
|
||||
</DropdownMenuPrimitive.CheckboxItem>
|
||||
));
|
||||
DropdownMenuCheckboxItem.displayName = DropdownMenuPrimitive.CheckboxItem.displayName;
|
||||
|
||||
const DropdownMenuRadioItem = React.forwardRef<
|
||||
React.ElementRef<typeof DropdownMenuPrimitive.RadioItem>,
|
||||
React.ComponentPropsWithoutRef<typeof DropdownMenuPrimitive.RadioItem>
|
||||
>(({ className = '', children, ...props }, ref) => (
|
||||
<DropdownMenuPrimitive.RadioItem
|
||||
ref={ref}
|
||||
className={cn(
|
||||
className,
|
||||
'relative flex cursor-pointer select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm font-medium outline-none focus:bg-gray-100 data-[disabled]:pointer-events-none data-[disabled]:opacity-50 dark:focus:bg-gray-800',
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
<span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
|
||||
<DropdownMenuPrimitive.ItemIndicator>
|
||||
<Circle className="h-2 w-2 fill-current" />
|
||||
</DropdownMenuPrimitive.ItemIndicator>
|
||||
</span>
|
||||
{children}
|
||||
</DropdownMenuPrimitive.RadioItem>
|
||||
));
|
||||
DropdownMenuRadioItem.displayName = DropdownMenuPrimitive.RadioItem.displayName;
|
||||
|
||||
const DropdownMenuLabel = React.forwardRef<
|
||||
React.ElementRef<typeof DropdownMenuPrimitive.Label>,
|
||||
React.ComponentPropsWithoutRef<typeof DropdownMenuPrimitive.Label> & {
|
||||
inset?: boolean;
|
||||
}
|
||||
>(({ className = '', inset, ...props }, ref) => (
|
||||
<DropdownMenuPrimitive.Label
|
||||
ref={ref}
|
||||
className={cn(
|
||||
'px-2 py-1.5 text-sm font-semibold text-gray-900 dark:text-gray-300',
|
||||
inset ? 'pl-8' : '',
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
DropdownMenuLabel.displayName = DropdownMenuPrimitive.Label.displayName;
|
||||
|
||||
const DropdownMenuSeparator = React.forwardRef<
|
||||
React.ElementRef<typeof DropdownMenuPrimitive.Separator>,
|
||||
React.ComponentPropsWithoutRef<typeof DropdownMenuPrimitive.Separator>
|
||||
>(({ className = '', ...props }, ref) => (
|
||||
<DropdownMenuPrimitive.Separator
|
||||
ref={ref}
|
||||
className={cn('-mx-1 my-1 h-px bg-border-medium', className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
DropdownMenuSeparator.displayName = DropdownMenuPrimitive.Separator.displayName;
|
||||
|
||||
const DropdownMenuShortcut = ({
|
||||
className = '',
|
||||
...props
|
||||
}: React.HTMLAttributes<HTMLSpanElement>) => {
|
||||
return (
|
||||
<span className={cn('ml-auto text-xs tracking-widest text-gray-500', className)} {...props} />
|
||||
);
|
||||
};
|
||||
DropdownMenuShortcut.displayName = 'DropdownMenuShortcut';
|
||||
}
|
||||
|
||||
function DropdownMenuCheckboxItem({
|
||||
className,
|
||||
children,
|
||||
checked,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuPrimitive.CheckboxItem>) {
|
||||
return (
|
||||
<DropdownMenuPrimitive.CheckboxItem
|
||||
data-slot="dropdown-menu-checkbox-item"
|
||||
className={cn(
|
||||
"outline-hidden relative flex cursor-default select-none items-center gap-2 rounded-sm py-1.5 pl-8 pr-2 text-sm focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg:not([class*='size-'])]:size-4 [&_svg]:pointer-events-none [&_svg]:shrink-0",
|
||||
className,
|
||||
)}
|
||||
checked={checked}
|
||||
{...props}
|
||||
>
|
||||
<span className="pointer-events-none absolute left-2 flex size-3.5 items-center justify-center">
|
||||
<DropdownMenuPrimitive.ItemIndicator>
|
||||
<CheckIcon className="size-4" />
|
||||
</DropdownMenuPrimitive.ItemIndicator>
|
||||
</span>
|
||||
{children}
|
||||
</DropdownMenuPrimitive.CheckboxItem>
|
||||
);
|
||||
}
|
||||
|
||||
function DropdownMenuRadioGroup({
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuPrimitive.RadioGroup>) {
|
||||
return <DropdownMenuPrimitive.RadioGroup data-slot="dropdown-menu-radio-group" {...props} />;
|
||||
}
|
||||
|
||||
function DropdownMenuRadioItem({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuPrimitive.RadioItem>) {
|
||||
return (
|
||||
<DropdownMenuPrimitive.RadioItem
|
||||
data-slot="dropdown-menu-radio-item"
|
||||
className={cn(
|
||||
"outline-hidden relative flex cursor-default select-none items-center gap-2 rounded-sm py-1.5 pl-8 pr-2 text-sm focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg:not([class*='size-'])]:size-4 [&_svg]:pointer-events-none [&_svg]:shrink-0",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
<span className="pointer-events-none absolute left-2 flex size-3.5 items-center justify-center">
|
||||
<DropdownMenuPrimitive.ItemIndicator>
|
||||
<CircleIcon className="size-2 fill-current" />
|
||||
</DropdownMenuPrimitive.ItemIndicator>
|
||||
</span>
|
||||
{children}
|
||||
</DropdownMenuPrimitive.RadioItem>
|
||||
);
|
||||
}
|
||||
|
||||
function DropdownMenuLabel({
|
||||
className,
|
||||
inset,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuPrimitive.Label> & {
|
||||
inset?: boolean;
|
||||
}) {
|
||||
return (
|
||||
<DropdownMenuPrimitive.Label
|
||||
data-slot="dropdown-menu-label"
|
||||
data-inset={inset}
|
||||
className={cn('px-2 py-1.5 text-sm font-medium data-[inset]:pl-8', className)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function DropdownMenuSeparator({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuPrimitive.Separator>) {
|
||||
return (
|
||||
<DropdownMenuPrimitive.Separator
|
||||
data-slot="dropdown-menu-separator"
|
||||
className={cn('-mx-1 my-1 h-px bg-surface-hover', className)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function DropdownMenuShortcut({ className, ...props }: React.ComponentProps<'span'>) {
|
||||
return (
|
||||
<span
|
||||
data-slot="dropdown-menu-shortcut"
|
||||
className={cn('ml-auto text-xs tracking-widest text-muted-foreground', className)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function DropdownMenuSub({ ...props }: React.ComponentProps<typeof DropdownMenuPrimitive.Sub>) {
|
||||
return <DropdownMenuPrimitive.Sub data-slot="dropdown-menu-sub" {...props} />;
|
||||
}
|
||||
|
||||
function DropdownMenuSubTrigger({
|
||||
className,
|
||||
inset,
|
||||
children,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuPrimitive.SubTrigger> & {
|
||||
inset?: boolean;
|
||||
}) {
|
||||
return (
|
||||
<DropdownMenuPrimitive.SubTrigger
|
||||
data-slot="dropdown-menu-sub-trigger"
|
||||
data-inset={inset}
|
||||
className={cn(
|
||||
'outline-hidden flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm focus:bg-accent focus:text-accent-foreground data-[state=open]:bg-accent data-[inset]:pl-8 data-[state=open]:text-accent-foreground',
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
<ChevronRightIcon className="ml-auto size-4" />
|
||||
</DropdownMenuPrimitive.SubTrigger>
|
||||
);
|
||||
}
|
||||
|
||||
function DropdownMenuSubContent({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuPrimitive.SubContent>) {
|
||||
return (
|
||||
<DropdownMenuPrimitive.SubContent
|
||||
data-slot="dropdown-menu-sub-content"
|
||||
className={cn(
|
||||
'text-popover-foreground origin-(--radix-dropdown-menu-content-transform-origin) z-50 min-w-[8rem] overflow-hidden rounded-md border border-border-medium bg-surface-secondary p-1 shadow-lg data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export {
|
||||
DropdownMenu,
|
||||
DropdownMenuPortal,
|
||||
DropdownMenuTrigger,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuGroup,
|
||||
DropdownMenuLabel,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuCheckboxItem,
|
||||
DropdownMenuRadioGroup,
|
||||
DropdownMenuRadioItem,
|
||||
DropdownMenuLabel,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuShortcut,
|
||||
DropdownMenuGroup,
|
||||
DropdownMenuPortal,
|
||||
DropdownMenuSub,
|
||||
DropdownMenuSubContent,
|
||||
DropdownMenuSubTrigger,
|
||||
DropdownMenuRadioGroup,
|
||||
DropdownMenuSubContent,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import { useState } from 'react';
|
||||
import { CircleHelpIcon } from 'lucide-react';
|
||||
import { HoverCard, HoverCardTrigger, HoverCardPortal, HoverCardContent } from './HoverCard';
|
||||
import { ESide } from '~/common';
|
||||
|
|
@ -8,15 +9,23 @@ type InfoHoverCardProps = {
|
|||
};
|
||||
|
||||
const InfoHoverCard = ({ side, text }: InfoHoverCardProps) => {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
|
||||
return (
|
||||
<HoverCard openDelay={50}>
|
||||
<HoverCardTrigger className="cursor-help">
|
||||
<CircleHelpIcon className="h-5 w-5 text-text-tertiary" />{' '}
|
||||
<HoverCard openDelay={50} open={isOpen} onOpenChange={setIsOpen}>
|
||||
<HoverCardTrigger
|
||||
tabIndex={0}
|
||||
className="inline-flex cursor-help items-center justify-center rounded-sm focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring-primary focus-visible:ring-offset-2"
|
||||
onFocus={() => setIsOpen(true)}
|
||||
onBlur={() => setIsOpen(false)}
|
||||
aria-label={text}
|
||||
>
|
||||
<CircleHelpIcon className="h-5 w-5 text-text-tertiary" aria-hidden="true" />
|
||||
</HoverCardTrigger>
|
||||
<HoverCardPortal>
|
||||
<HoverCardContent side={side} className="z-[999] w-80">
|
||||
<div className="space-y-2">
|
||||
<p className="text-sm text-text-secondary">{text}</p>
|
||||
<span className="text-sm text-text-secondary">{text}</span>
|
||||
</div>
|
||||
</HoverCardContent>
|
||||
</HoverCardPortal>
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ const Label = React.forwardRef<
|
|||
{...props}
|
||||
{...{
|
||||
className: cn(
|
||||
'block w-full break-all text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70 dark:text-gray-200',
|
||||
'block w-full break-all text-sm leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70 dark:text-gray-200',
|
||||
className,
|
||||
),
|
||||
}}
|
||||
|
|
|
|||
|
|
@ -2,37 +2,56 @@ import * as React from 'react';
|
|||
import * as SliderPrimitive from '@radix-ui/react-slider';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
const Slider = React.forwardRef<
|
||||
React.ElementRef<typeof SliderPrimitive.Root>,
|
||||
React.ComponentPropsWithoutRef<typeof SliderPrimitive.Root> & {
|
||||
className?: string;
|
||||
onDoubleClick?: () => void;
|
||||
}
|
||||
>(({ className, onDoubleClick, ...props }, ref) => (
|
||||
<SliderPrimitive.Root
|
||||
ref={ref}
|
||||
{...props}
|
||||
{...{
|
||||
className: cn(
|
||||
'relative flex w-full cursor-pointer touch-none select-none items-center',
|
||||
className,
|
||||
),
|
||||
type SliderProps = React.ComponentPropsWithoutRef<typeof SliderPrimitive.Root> & {
|
||||
className?: string;
|
||||
onDoubleClick?: () => void;
|
||||
'aria-describedby'?: string;
|
||||
} & (
|
||||
| { 'aria-label': string; 'aria-labelledby'?: never }
|
||||
| { 'aria-labelledby': string; 'aria-label'?: never }
|
||||
| { 'aria-label': string; 'aria-labelledby': string }
|
||||
);
|
||||
|
||||
const Slider = React.forwardRef<React.ElementRef<typeof SliderPrimitive.Root>, SliderProps>(
|
||||
(
|
||||
{
|
||||
className,
|
||||
onDoubleClick,
|
||||
}}
|
||||
>
|
||||
<SliderPrimitive.Track
|
||||
{...{ className: 'relative h-2 w-full grow overflow-hidden rounded-full bg-secondary' }}
|
||||
>
|
||||
<SliderPrimitive.Range {...{ className: 'absolute h-full bg-primary' }} />
|
||||
</SliderPrimitive.Track>
|
||||
<SliderPrimitive.Thumb
|
||||
'aria-labelledby': ariaLabelledBy,
|
||||
'aria-label': ariaLabel,
|
||||
'aria-describedby': ariaDescribedBy,
|
||||
...props
|
||||
},
|
||||
ref,
|
||||
) => (
|
||||
<SliderPrimitive.Root
|
||||
ref={ref}
|
||||
{...props}
|
||||
{...{
|
||||
className:
|
||||
'block h-5 w-5 rounded-full border-2 border-primary bg-background ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50',
|
||||
className: cn(
|
||||
'relative flex w-full cursor-pointer touch-none select-none items-center',
|
||||
className,
|
||||
),
|
||||
onDoubleClick,
|
||||
}}
|
||||
/>
|
||||
</SliderPrimitive.Root>
|
||||
));
|
||||
>
|
||||
<SliderPrimitive.Track
|
||||
{...{ className: 'relative h-2 w-full grow overflow-hidden rounded-full bg-secondary' }}
|
||||
>
|
||||
<SliderPrimitive.Range {...{ className: 'absolute h-full bg-primary' }} />
|
||||
</SliderPrimitive.Track>
|
||||
<SliderPrimitive.Thumb
|
||||
{...{
|
||||
className:
|
||||
'block h-5 w-5 rounded-full border-2 border-primary bg-background ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50',
|
||||
'aria-labelledby': ariaLabelledBy,
|
||||
'aria-label': ariaLabel,
|
||||
'aria-describedby': ariaDescribedBy,
|
||||
}}
|
||||
/>
|
||||
</SliderPrimitive.Root>
|
||||
),
|
||||
);
|
||||
Slider.displayName = SliderPrimitive.Root.displayName;
|
||||
|
||||
export { Slider };
|
||||
|
|
|
|||
|
|
@ -214,6 +214,14 @@ export const bedrockEndpointSchema = baseEndpointSchema.merge(
|
|||
}),
|
||||
);
|
||||
|
||||
const modelItemSchema = z.union([
|
||||
z.string(),
|
||||
z.object({
|
||||
name: z.string(),
|
||||
description: z.string().optional(),
|
||||
}),
|
||||
]);
|
||||
|
||||
export const assistantEndpointSchema = baseEndpointSchema.merge(
|
||||
z.object({
|
||||
/* assistants specific */
|
||||
|
|
@ -239,7 +247,7 @@ export const assistantEndpointSchema = baseEndpointSchema.merge(
|
|||
apiKey: z.string().optional(),
|
||||
models: z
|
||||
.object({
|
||||
default: z.array(z.string()).min(1),
|
||||
default: z.array(modelItemSchema).min(1),
|
||||
fetch: z.boolean().optional(),
|
||||
userIdQuery: z.boolean().optional(),
|
||||
})
|
||||
|
|
@ -299,7 +307,7 @@ export const endpointSchema = baseEndpointSchema.merge(
|
|||
apiKey: z.string(),
|
||||
baseURL: z.string(),
|
||||
models: z.object({
|
||||
default: z.array(z.string()).min(1),
|
||||
default: z.array(modelItemSchema).min(1),
|
||||
fetch: z.boolean().optional(),
|
||||
userIdQuery: z.boolean().optional(),
|
||||
}),
|
||||
|
|
@ -636,6 +644,7 @@ export type TStartupConfig = {
|
|||
helpAndFaqURL: string;
|
||||
customFooter?: string;
|
||||
modelSpecs?: TSpecsConfig;
|
||||
modelDescriptions?: Record<string, Record<string, string>>;
|
||||
sharedLinksEnabled: boolean;
|
||||
publicSharedLinksEnabled: boolean;
|
||||
analyticsGtmId?: string;
|
||||
|
|
@ -669,6 +678,7 @@ export type TStartupConfig = {
|
|||
}
|
||||
>;
|
||||
mcpPlaceholder?: string;
|
||||
conversationImportMaxFileSize?: number;
|
||||
};
|
||||
|
||||
export enum OCRStrategy {
|
||||
|
|
|
|||
|
|
@ -42,8 +42,11 @@ export function getSharedLink(conversationId: string): Promise<t.TSharedLinkGetR
|
|||
return request.get(endpoints.getSharedLink(conversationId));
|
||||
}
|
||||
|
||||
export function createSharedLink(conversationId: string): Promise<t.TSharedLinkResponse> {
|
||||
return request.post(endpoints.createSharedLink(conversationId));
|
||||
export function createSharedLink(
|
||||
conversationId: string,
|
||||
targetMessageId?: string,
|
||||
): Promise<t.TSharedLinkResponse> {
|
||||
return request.post(endpoints.createSharedLink(conversationId), { targetMessageId });
|
||||
}
|
||||
|
||||
export function updateSharedLink(shareId: string): Promise<t.TSharedLinkResponse> {
|
||||
|
|
|
|||
|
|
@ -62,6 +62,8 @@ const BaseOptionsSchema = z.object({
|
|||
revocation_endpoint_auth_methods_supported: z.array(z.string()).optional(),
|
||||
})
|
||||
.optional(),
|
||||
/** Custom headers to send with OAuth requests (registration, discovery, token exchange, etc.) */
|
||||
oauth_headers: z.record(z.string(), z.string()).optional(),
|
||||
customUserVars: z
|
||||
.record(
|
||||
z.string(),
|
||||
|
|
|
|||
|
|
@ -15,6 +15,13 @@ export type TModelSpec = {
|
|||
order?: number;
|
||||
default?: boolean;
|
||||
description?: string;
|
||||
/**
|
||||
* Optional group name for organizing specs in the UI selector.
|
||||
* - If it matches an endpoint name (e.g., "openAI", "groq"), the spec appears nested under that endpoint
|
||||
* - If it's a custom name (doesn't match any endpoint), it creates a separate collapsible group
|
||||
* - If omitted, the spec appears as a standalone item at the top level
|
||||
*/
|
||||
group?: string;
|
||||
showIconInMenu?: boolean;
|
||||
showIconInHeader?: boolean;
|
||||
iconURL?: string | EModelEndpoint; // Allow using project-included icons
|
||||
|
|
@ -28,6 +35,7 @@ export const tModelSpecSchema = z.object({
|
|||
order: z.number().optional(),
|
||||
default: z.boolean().optional(),
|
||||
description: z.string().optional(),
|
||||
group: z.string().optional(),
|
||||
showIconInMenu: z.boolean().optional(),
|
||||
showIconInHeader: z.boolean().optional(),
|
||||
iconURL: z.union([z.string(), eModelEndpointSchema]).optional(),
|
||||
|
|
|
|||
|
|
@ -82,6 +82,77 @@ function anonymizeMessages(messages: t.IMessage[], newConvoId: string): t.IMessa
|
|||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Filter messages up to and including the target message (branch-specific)
|
||||
* Similar to getMessagesUpToTargetLevel from fork utilities
|
||||
*/
|
||||
function getMessagesUpToTarget(messages: t.IMessage[], targetMessageId: string): t.IMessage[] {
|
||||
if (!messages || messages.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
// If only one message and it's the target, return it
|
||||
if (messages.length === 1 && messages[0]?.messageId === targetMessageId) {
|
||||
return messages;
|
||||
}
|
||||
|
||||
// Create a map of parentMessageId to children messages
|
||||
const parentToChildrenMap = new Map<string, t.IMessage[]>();
|
||||
for (const message of messages) {
|
||||
const parentId = message.parentMessageId || Constants.NO_PARENT;
|
||||
if (!parentToChildrenMap.has(parentId)) {
|
||||
parentToChildrenMap.set(parentId, []);
|
||||
}
|
||||
parentToChildrenMap.get(parentId)?.push(message);
|
||||
}
|
||||
|
||||
// Find the target message
|
||||
const targetMessage = messages.find((msg) => msg.messageId === targetMessageId);
|
||||
if (!targetMessage) {
|
||||
// If target not found, return all messages for backwards compatibility
|
||||
return messages;
|
||||
}
|
||||
|
||||
const visited = new Set<string>();
|
||||
const rootMessages = parentToChildrenMap.get(Constants.NO_PARENT) || [];
|
||||
let currentLevel = rootMessages.length > 0 ? [...rootMessages] : [targetMessage];
|
||||
const results = new Set<t.IMessage>(currentLevel);
|
||||
|
||||
// Check if the target message is at the root level
|
||||
if (
|
||||
currentLevel.some((msg) => msg.messageId === targetMessageId) &&
|
||||
targetMessage.parentMessageId === Constants.NO_PARENT
|
||||
) {
|
||||
return Array.from(results);
|
||||
}
|
||||
|
||||
// Iterate level by level until the target is found
|
||||
let targetFound = false;
|
||||
while (!targetFound && currentLevel.length > 0) {
|
||||
const nextLevel: t.IMessage[] = [];
|
||||
for (const node of currentLevel) {
|
||||
if (visited.has(node.messageId)) {
|
||||
continue;
|
||||
}
|
||||
visited.add(node.messageId);
|
||||
const children = parentToChildrenMap.get(node.messageId) || [];
|
||||
for (const child of children) {
|
||||
if (visited.has(child.messageId)) {
|
||||
continue;
|
||||
}
|
||||
nextLevel.push(child);
|
||||
results.add(child);
|
||||
if (child.messageId === targetMessageId) {
|
||||
targetFound = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
currentLevel = nextLevel;
|
||||
}
|
||||
|
||||
return Array.from(results);
|
||||
}
|
||||
|
||||
/** Factory function that takes mongoose instance and returns the methods */
|
||||
export function createShareMethods(mongoose: typeof import('mongoose')) {
|
||||
/**
|
||||
|
|
@ -102,6 +173,12 @@ export function createShareMethods(mongoose: typeof import('mongoose')) {
|
|||
return null;
|
||||
}
|
||||
|
||||
// Filter messages based on targetMessageId if present (branch-specific sharing)
|
||||
let messagesToShare = share.messages;
|
||||
if (share.targetMessageId) {
|
||||
messagesToShare = getMessagesUpToTarget(share.messages, share.targetMessageId);
|
||||
}
|
||||
|
||||
const newConvoId = anonymizeConvoId(share.conversationId);
|
||||
const result: t.SharedMessagesResult = {
|
||||
shareId: share.shareId || shareId,
|
||||
|
|
@ -110,7 +187,7 @@ export function createShareMethods(mongoose: typeof import('mongoose')) {
|
|||
createdAt: share.createdAt,
|
||||
updatedAt: share.updatedAt,
|
||||
conversationId: newConvoId,
|
||||
messages: anonymizeMessages(share.messages, newConvoId),
|
||||
messages: anonymizeMessages(messagesToShare, newConvoId),
|
||||
};
|
||||
|
||||
return result;
|
||||
|
|
@ -239,6 +316,7 @@ export function createShareMethods(mongoose: typeof import('mongoose')) {
|
|||
async function createSharedLink(
|
||||
user: string,
|
||||
conversationId: string,
|
||||
targetMessageId?: string,
|
||||
): Promise<t.CreateShareResult> {
|
||||
if (!user || !conversationId) {
|
||||
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
|
||||
|
|
@ -249,7 +327,12 @@ export function createShareMethods(mongoose: typeof import('mongoose')) {
|
|||
const Conversation = mongoose.models.Conversation as SchemaWithMeiliMethods;
|
||||
|
||||
const [existingShare, conversationMessages] = await Promise.all([
|
||||
SharedLink.findOne({ conversationId, user, isPublic: true })
|
||||
SharedLink.findOne({
|
||||
conversationId,
|
||||
user,
|
||||
isPublic: true,
|
||||
...(targetMessageId && { targetMessageId }),
|
||||
})
|
||||
.select('-_id -__v -user')
|
||||
.lean() as Promise<t.ISharedLink | null>,
|
||||
Message.find({ conversationId, user }).sort({ createdAt: 1 }).lean(),
|
||||
|
|
@ -259,10 +342,15 @@ export function createShareMethods(mongoose: typeof import('mongoose')) {
|
|||
logger.error('[createSharedLink] Share already exists', {
|
||||
user,
|
||||
conversationId,
|
||||
targetMessageId,
|
||||
});
|
||||
throw new ShareServiceError('Share already exists', 'SHARE_EXISTS');
|
||||
} else if (existingShare) {
|
||||
await SharedLink.deleteOne({ conversationId, user });
|
||||
await SharedLink.deleteOne({
|
||||
conversationId,
|
||||
user,
|
||||
...(targetMessageId && { targetMessageId }),
|
||||
});
|
||||
}
|
||||
|
||||
const conversation = (await Conversation.findOne({ conversationId, user }).lean()) as {
|
||||
|
|
@ -291,6 +379,7 @@ export function createShareMethods(mongoose: typeof import('mongoose')) {
|
|||
messages: conversationMessages,
|
||||
title,
|
||||
user,
|
||||
...(targetMessageId && { targetMessageId }),
|
||||
});
|
||||
|
||||
return { shareId, conversationId };
|
||||
|
|
@ -302,6 +391,7 @@ export function createShareMethods(mongoose: typeof import('mongoose')) {
|
|||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
user,
|
||||
conversationId,
|
||||
targetMessageId,
|
||||
});
|
||||
throw new ShareServiceError('Error creating shared link', 'SHARE_CREATE_ERROR');
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ export interface ISharedLink extends Document {
|
|||
user?: string;
|
||||
messages?: Types.ObjectId[];
|
||||
shareId?: string;
|
||||
targetMessageId?: string;
|
||||
isPublic: boolean;
|
||||
createdAt?: Date;
|
||||
updatedAt?: Date;
|
||||
|
|
@ -30,6 +31,11 @@ const shareSchema: Schema<ISharedLink> = new Schema(
|
|||
type: String,
|
||||
index: true,
|
||||
},
|
||||
targetMessageId: {
|
||||
type: String,
|
||||
required: false,
|
||||
index: true,
|
||||
},
|
||||
isPublic: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
|
|
@ -38,4 +44,6 @@ const shareSchema: Schema<ISharedLink> = new Schema(
|
|||
{ timestamps: true },
|
||||
);
|
||||
|
||||
shareSchema.index({ conversationId: 1, user: 1, targetMessageId: 1 });
|
||||
|
||||
export default shareSchema;
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ export interface ISharedLink {
|
|||
user?: string;
|
||||
messages?: Types.ObjectId[];
|
||||
shareId?: string;
|
||||
targetMessageId?: string;
|
||||
isPublic: boolean;
|
||||
createdAt?: Date;
|
||||
updatedAt?: Date;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue