🎨 fix: Optimize StableDiffusion API Tool and Fix for Assistants Usage (#2253)

* chore: update docs

* fix(StableDiffusion): optimize API responses and file handling, return expected metadata for Assistants endpoint
This commit is contained in:
Danny Avila 2024-03-30 20:09:59 -04:00 committed by GitHub
parent 56ea0f9ae7
commit bb8a40dd98
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 113 additions and 52 deletions

View file

@ -4,14 +4,27 @@ const { z } = require('zod');
const path = require('path');
const axios = require('axios');
const sharp = require('sharp');
const { v4: uuidv4 } = require('uuid');
const { StructuredTool } = require('langchain/tools');
const { FileContext } = require('librechat-data-provider');
const paths = require('~/config/paths');
const { logger } = require('~/config');
class StableDiffusionAPI extends StructuredTool {
constructor(fields) {
super();
/* Used to initialize the Tool without necessary variables. */
/** @type {string} User ID */
this.userId = fields.userId;
/** @type {Express.Request | undefined} Express Request object, only provided by ToolService */
this.req = fields.req;
/** @type {boolean} Used to initialize the Tool without necessary variables. */
this.override = fields.override ?? false;
/** @type {boolean} Necessary for output to contain all image metadata. */
this.returnMetadata = fields.returnMetadata ?? false;
if (fields.uploadImageBuffer) {
/** @type {uploadImageBuffer} Necessary for output to contain all image metadata. */
this.uploadImageBuffer = fields.uploadImageBuffer.bind(this);
}
this.name = 'stable-diffusion';
this.url = fields.SD_WEBUI_URL || this.getServerURL();
@ -47,7 +60,7 @@ class StableDiffusionAPI extends StructuredTool {
getMarkdownImageUrl(imageName) {
const imageUrl = path
.join(this.relativeImageUrl, imageName)
.join(this.relativePath, this.userId, imageName)
.replace(/\\/g, '/')
.replace('public/', '');
return `![generated image](/${imageUrl})`;
@ -73,46 +86,67 @@ class StableDiffusionAPI extends StructuredTool {
width: 1024,
height: 1024,
};
const response = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
const image = response.data.images[0];
const pngPayload = { image: `data:image/png;base64,${image}` };
const response2 = await axios.post(`${url}/sdapi/v1/png-info`, pngPayload);
const info = response2.data.info;
const generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
const image = generationResponse.data.images[0];
// Generate unique name
const imageName = `${Date.now()}.png`;
this.outputPath = path.resolve(
__dirname,
'..',
'..',
'..',
'..',
'..',
'client',
'public',
'images',
);
const appRoot = path.resolve(__dirname, '..', '..', '..', '..', '..', 'client');
this.relativeImageUrl = path.relative(appRoot, this.outputPath);
/** @type {{ height: number, width: number, seed: number, infotexts: string[] }} */
let info = {};
try {
info = JSON.parse(generationResponse.data.info);
} catch (error) {
logger.error('[StableDiffusion] Error while getting image metadata:', error);
}
// Check if directory exists, if not create it
if (!fs.existsSync(this.outputPath)) {
fs.mkdirSync(this.outputPath, { recursive: true });
const file_id = uuidv4();
const imageName = `${file_id}.png`;
const { imageOutput: imageOutputPath, clientPath } = paths;
const filepath = path.join(imageOutputPath, this.userId, imageName);
this.relativePath = path.relative(clientPath, imageOutputPath);
if (!fs.existsSync(imageOutputPath)) {
fs.mkdirSync(imageOutputPath, { recursive: true });
}
try {
const buffer = Buffer.from(image.split(',', 1)[0], 'base64');
if (this.returnMetadata && this.uploadImageBuffer && this.req) {
const file = await this.uploadImageBuffer({
req: this.req,
context: FileContext.image_generation,
resize: false,
metadata: {
buffer,
height: info.height,
width: info.width,
bytes: Buffer.byteLength(buffer),
filename: imageName,
type: 'image/png',
file_id,
},
});
const generationInfo = info.infotexts[0].split('\n').pop();
return {
...file,
prompt,
metadata: {
negative_prompt,
seed: info.seed,
info: generationInfo,
},
};
}
await sharp(buffer)
.withMetadata({
iptcpng: {
parameters: info,
parameters: info.infotexts[0],
},
})
.toFile(this.outputPath + '/' + imageName);
.toFile(filepath);
this.result = this.getMarkdownImageUrl(imageName);
} catch (error) {
logger.error('[StableDiffusion] Error while saving the image:', error);
// this.result = theImageUrl;
}
return this.result;