🎯 fix: Google AI Client Stability; feat: gemini-exp models (#4781)

* fix: Google timing out and issuing AbortError, bump package, and use `@google/generative-ai` explicitly for latest models

* feat: gemini-exp-
This commit is contained in:
Danny Avila 2024-11-22 19:08:14 -05:00 committed by GitHub
parent 56b60cf863
commit 2a77c98f51
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 19 additions and 22 deletions

View file

@ -35,6 +35,7 @@ const endpointPrefix = `https://${loc}-aiplatform.googleapis.com`;
const tokenizersCache = {};
const settings = endpointSettings[EModelEndpoint.google];
const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/;
class GoogleClient extends BaseClient {
constructor(credentials, options = {}) {
@ -366,7 +367,7 @@ class GoogleClient extends BaseClient {
);
}
if (!this.project_id && this.modelOptions.model.includes('1.5')) {
if (!this.project_id && !EXCLUDED_GENAI_MODELS.test(this.modelOptions.model)) {
return await this.buildGenerativeMessages(messages);
}
@ -604,15 +605,12 @@ class GoogleClient extends BaseClient {
} else if (this.project_id) {
logger.debug('Creating VertexAI client');
return new ChatVertexAI(clientOptions);
} else if (model.includes('1.5')) {
} else if (!EXCLUDED_GENAI_MODELS.test(model)) {
logger.debug('Creating GenAI client');
return new GenAI(this.apiKey).getGenerativeModel(
{
...clientOptions,
model,
},
{ apiVersion: 'v1beta' },
);
return new GenAI(this.apiKey).getGenerativeModel({
...clientOptions,
model,
});
}
logger.debug('Creating Chat Google Generative AI client');
@ -674,7 +672,7 @@ class GoogleClient extends BaseClient {
}
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
if (modelName?.includes('1.5') && !this.project_id) {
if (!EXCLUDED_GENAI_MODELS.test(modelName) && !this.project_id) {
const client = model;
const requestOptions = {
contents: _payload,
@ -697,7 +695,7 @@ class GoogleClient extends BaseClient {
requestOptions.safetySettings = _payload.safetySettings;
const delay = modelName.includes('flash') ? 8 : 14;
const delay = modelName.includes('flash') ? 8 : 15;
const result = await client.generateContentStream(requestOptions);
for await (const chunk of result.stream) {
const chunkText = chunk.text();
@ -712,7 +710,6 @@ class GoogleClient extends BaseClient {
const stream = await model.stream(messages, {
signal: abortController.signal,
timeout: 7000,
safetySettings: _payload.safetySettings,
});
@ -720,7 +717,7 @@ class GoogleClient extends BaseClient {
if (!this.options.streamRate) {
if (this.isGenerativeModel) {
delay = 12;
delay = 15;
}
if (modelName.includes('flash')) {
delay = 5;
@ -774,8 +771,8 @@ class GoogleClient extends BaseClient {
const messages = this.isTextModel ? _payload.trim() : _messages;
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
if (modelName?.includes('1.5') && !this.project_id) {
logger.debug('Identified titling model as 1.5 version');
if (!EXCLUDED_GENAI_MODELS.test(modelName) && !this.project_id) {
logger.debug('Identified titling model as GenAI version');
/** @type {GenerativeModel} */
const client = model;
const requestOptions = {