mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-16 07:25:31 +01:00
🪨 fix: Bedrock Provider Support for Memory Agent (#11353)
Some checks are pending
Docker Dev Branch Images Build / build (Dockerfile, lc-dev, node) (push) Waiting to run
Docker Dev Branch Images Build / build (Dockerfile.multi, lc-dev-api, api-build) (push) Waiting to run
Docker Dev Images Build / build (Dockerfile, librechat-dev, node) (push) Waiting to run
Docker Dev Images Build / build (Dockerfile.multi, librechat-dev-api, api-build) (push) Waiting to run
Sync Locize Translations & Create Translation PR / Sync Translation Keys with Locize (push) Waiting to run
Sync Locize Translations & Create Translation PR / Create Translation PR on Version Published (push) Blocked by required conditions
Some checks are pending
Docker Dev Branch Images Build / build (Dockerfile, lc-dev, node) (push) Waiting to run
Docker Dev Branch Images Build / build (Dockerfile.multi, lc-dev-api, api-build) (push) Waiting to run
Docker Dev Images Build / build (Dockerfile, librechat-dev, node) (push) Waiting to run
Docker Dev Images Build / build (Dockerfile.multi, librechat-dev-api, api-build) (push) Waiting to run
Sync Locize Translations & Create Translation PR / Sync Translation Keys with Locize (push) Waiting to run
Sync Locize Translations & Create Translation PR / Create Translation PR on Version Published (push) Blocked by required conditions
* feat: Bedrock provider support in memory processing - Introduced support for the Bedrock provider in the memory processing logic. - Updated the handling of instructions to ensure they are included in user messages for Bedrock, while maintaining the standard approach for other providers. - Added tests to verify the correct behavior for both Bedrock and non-Bedrock providers regarding instruction handling. * refactor: Bedrock memory processing logic - Improved handling of the first message in Bedrock memory processing to ensure proper content is used. - Added logging for cases where the first message content is not a string. - Adjusted the processed messages to include the original content or fallback to a new HumanMessage if no messages are present. * feat: Enhance Bedrock configuration handling in memory processing - Added logic to set the temperature to 1 when using the Bedrock provider with thinking enabled. - Ensured compatibility with additional model request fields for improved memory processing.
This commit is contained in:
parent
b5e4c763af
commit
9562f9297a
2 changed files with 145 additions and 8 deletions
|
|
@ -1,17 +1,42 @@
|
|||
import { Types } from 'mongoose';
|
||||
import type { Response } from 'express';
|
||||
import { Run } from '@librechat/agents';
|
||||
import type { IUser } from '@librechat/data-schemas';
|
||||
import { createSafeUser } from '~/utils/env';
|
||||
import type { Response } from 'express';
|
||||
import { processMemory } from './memory';
|
||||
|
||||
jest.mock('~/stream/GenerationJobManager');
|
||||
|
||||
const mockCreateSafeUser = jest.fn((user) => ({
|
||||
id: user?.id,
|
||||
email: user?.email,
|
||||
name: user?.name,
|
||||
username: user?.username,
|
||||
}));
|
||||
|
||||
const mockResolveHeaders = jest.fn((opts) => {
|
||||
const headers = opts.headers || {};
|
||||
const user = opts.user || {};
|
||||
const result: Record<string, string> = {};
|
||||
for (const [key, value] of Object.entries(headers)) {
|
||||
let resolved = value as string;
|
||||
resolved = resolved.replace(/\$\{(\w+)\}/g, (_match, envVar) => process.env[envVar] || '');
|
||||
resolved = resolved.replace(/\{\{LIBRECHAT_USER_EMAIL\}\}/g, user.email || '');
|
||||
resolved = resolved.replace(/\{\{LIBRECHAT_USER_ID\}\}/g, user.id || '');
|
||||
result[key] = resolved;
|
||||
}
|
||||
return result;
|
||||
});
|
||||
|
||||
jest.mock('~/utils', () => ({
|
||||
Tokenizer: {
|
||||
getTokenCount: jest.fn(() => 10),
|
||||
},
|
||||
createSafeUser: (user: unknown) => mockCreateSafeUser(user),
|
||||
resolveHeaders: (opts: unknown) => mockResolveHeaders(opts),
|
||||
}));
|
||||
|
||||
const { createSafeUser } = jest.requireMock('~/utils');
|
||||
|
||||
jest.mock('@librechat/agents', () => ({
|
||||
Run: {
|
||||
create: jest.fn(() => ({
|
||||
|
|
@ -20,6 +45,7 @@ jest.mock('@librechat/agents', () => ({
|
|||
},
|
||||
Providers: {
|
||||
OPENAI: 'openai',
|
||||
BEDROCK: 'bedrock',
|
||||
},
|
||||
GraphEvents: {
|
||||
TOOL_END: 'tool_end',
|
||||
|
|
@ -295,4 +321,65 @@ describe('Memory Agent Header Resolution', () => {
|
|||
expect(safeUser).toHaveProperty('id');
|
||||
expect(safeUser).toHaveProperty('email');
|
||||
});
|
||||
|
||||
it('should include instructions in user message for Bedrock provider', async () => {
|
||||
const llmConfig = {
|
||||
provider: 'bedrock',
|
||||
model: 'us.anthropic.claude-haiku-4-5-20251001-v1:0',
|
||||
};
|
||||
|
||||
const { HumanMessage } = await import('@langchain/core/messages');
|
||||
const testMessage = new HumanMessage('test chat content');
|
||||
|
||||
await processMemory({
|
||||
res: mockRes,
|
||||
userId: 'user-123',
|
||||
setMemory: mockMemoryMethods.setMemory,
|
||||
deleteMemory: mockMemoryMethods.deleteMemory,
|
||||
messages: [testMessage],
|
||||
memory: 'existing memory',
|
||||
messageId: 'msg-123',
|
||||
conversationId: 'conv-123',
|
||||
validKeys: ['preferences'],
|
||||
instructions: 'test instructions',
|
||||
llmConfig,
|
||||
user: testUser,
|
||||
});
|
||||
|
||||
expect(Run.create as jest.Mock).toHaveBeenCalled();
|
||||
const runConfig = (Run.create as jest.Mock).mock.calls[0][0];
|
||||
|
||||
// For Bedrock, instructions should NOT be passed to graphConfig
|
||||
expect(runConfig.graphConfig.instructions).toBeUndefined();
|
||||
expect(runConfig.graphConfig.additional_instructions).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should pass instructions to graphConfig for non-Bedrock providers', async () => {
|
||||
const llmConfig = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4o-mini',
|
||||
};
|
||||
|
||||
await processMemory({
|
||||
res: mockRes,
|
||||
userId: 'user-123',
|
||||
setMemory: mockMemoryMethods.setMemory,
|
||||
deleteMemory: mockMemoryMethods.deleteMemory,
|
||||
messages: [],
|
||||
memory: 'existing memory',
|
||||
messageId: 'msg-123',
|
||||
conversationId: 'conv-123',
|
||||
validKeys: ['preferences'],
|
||||
instructions: 'test instructions',
|
||||
llmConfig,
|
||||
user: testUser,
|
||||
});
|
||||
|
||||
expect(Run.create as jest.Mock).toHaveBeenCalled();
|
||||
const runConfig = (Run.create as jest.Mock).mock.calls[0][0];
|
||||
|
||||
// For non-Bedrock providers, instructions should be passed to graphConfig
|
||||
expect(runConfig.graphConfig.instructions).toBe('test instructions');
|
||||
expect(runConfig.graphConfig.additional_instructions).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import { z } from 'zod';
|
|||
import { tool } from '@langchain/core/tools';
|
||||
import { Tools } from 'librechat-data-provider';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { HumanMessage } from '@langchain/core/messages';
|
||||
import { Run, Providers, GraphEvents } from '@librechat/agents';
|
||||
import type {
|
||||
OpenAIClientOptions,
|
||||
|
|
@ -13,13 +14,12 @@ import type {
|
|||
ToolEndData,
|
||||
LLMConfig,
|
||||
} from '@librechat/agents';
|
||||
import type { TAttachment, MemoryArtifact } from 'librechat-data-provider';
|
||||
import type { ObjectId, MemoryMethods, IUser } from '@librechat/data-schemas';
|
||||
import type { TAttachment, MemoryArtifact } from 'librechat-data-provider';
|
||||
import type { BaseMessage, ToolMessage } from '@langchain/core/messages';
|
||||
import type { Response as ServerResponse } from 'express';
|
||||
import { GenerationJobManager } from '~/stream/GenerationJobManager';
|
||||
import { resolveHeaders, createSafeUser } from '~/utils/env';
|
||||
import { Tokenizer } from '~/utils';
|
||||
import { Tokenizer, resolveHeaders, createSafeUser } from '~/utils';
|
||||
|
||||
type RequiredMemoryMethods = Pick<
|
||||
MemoryMethods,
|
||||
|
|
@ -369,6 +369,19 @@ ${memory ?? 'No existing memories'}`;
|
|||
}
|
||||
}
|
||||
|
||||
// Handle Bedrock with thinking enabled - temperature must be 1
|
||||
const bedrockConfig = finalLLMConfig as {
|
||||
additionalModelRequestFields?: { thinking?: unknown };
|
||||
temperature?: number;
|
||||
};
|
||||
if (
|
||||
llmConfig?.provider === Providers.BEDROCK &&
|
||||
bedrockConfig.additionalModelRequestFields?.thinking != null &&
|
||||
bedrockConfig.temperature != null
|
||||
) {
|
||||
(finalLLMConfig as unknown as Record<string, unknown>).temperature = 1;
|
||||
}
|
||||
|
||||
const llmConfigWithHeaders = finalLLMConfig as OpenAIClientOptions;
|
||||
if (llmConfigWithHeaders?.configuration?.defaultHeaders != null) {
|
||||
llmConfigWithHeaders.configuration.defaultHeaders = resolveHeaders({
|
||||
|
|
@ -383,14 +396,51 @@ ${memory ?? 'No existing memories'}`;
|
|||
[GraphEvents.TOOL_END]: new BasicToolEndHandler(memoryCallback),
|
||||
};
|
||||
|
||||
/**
|
||||
* For Bedrock provider, include instructions in the user message instead of as a system prompt.
|
||||
* Bedrock's Converse API requires conversations to start with a user message, not a system message.
|
||||
* Other providers can use the standard system prompt approach.
|
||||
*/
|
||||
const isBedrock = llmConfig?.provider === Providers.BEDROCK;
|
||||
|
||||
let graphInstructions: string | undefined = instructions;
|
||||
let graphAdditionalInstructions: string | undefined = memoryStatus;
|
||||
let processedMessages = messages;
|
||||
|
||||
if (isBedrock) {
|
||||
const combinedInstructions = [instructions, memoryStatus].filter(Boolean).join('\n\n');
|
||||
|
||||
if (messages.length > 0) {
|
||||
const firstMessage = messages[0];
|
||||
const originalContent =
|
||||
typeof firstMessage.content === 'string' ? firstMessage.content : '';
|
||||
|
||||
if (typeof firstMessage.content !== 'string') {
|
||||
logger.warn(
|
||||
'Bedrock memory processing: First message has non-string content, using empty string',
|
||||
);
|
||||
}
|
||||
|
||||
const bedrockUserMessage = new HumanMessage(
|
||||
`${combinedInstructions}\n\n${originalContent}`,
|
||||
);
|
||||
processedMessages = [bedrockUserMessage, ...messages.slice(1)];
|
||||
} else {
|
||||
processedMessages = [new HumanMessage(combinedInstructions)];
|
||||
}
|
||||
|
||||
graphInstructions = undefined;
|
||||
graphAdditionalInstructions = undefined;
|
||||
}
|
||||
|
||||
const run = await Run.create({
|
||||
runId: messageId,
|
||||
graphConfig: {
|
||||
type: 'standard',
|
||||
llmConfig: finalLLMConfig,
|
||||
tools: [memoryTool, deleteMemoryTool],
|
||||
instructions,
|
||||
additional_instructions: memoryStatus,
|
||||
instructions: graphInstructions,
|
||||
additional_instructions: graphAdditionalInstructions,
|
||||
toolEnd: true,
|
||||
},
|
||||
customHandlers,
|
||||
|
|
@ -410,7 +460,7 @@ ${memory ?? 'No existing memories'}`;
|
|||
} as const;
|
||||
|
||||
const inputs = {
|
||||
messages,
|
||||
messages: processedMessages,
|
||||
};
|
||||
const content = await run.processStream(inputs, config);
|
||||
if (content) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue