From 20af6b0b7aad9b40b4c52b5f742b91672f474831 Mon Sep 17 00:00:00 2001 From: Liang Ding Date: Wed, 8 Mar 2023 21:28:29 +0800 Subject: [PATCH] =?UTF-8?q?:art:=20=E5=8A=A0=E5=85=A5=E9=92=88=E5=AF=B9?= =?UTF-8?q?=E5=86=85=E5=AE=B9=E5=9D=97=E7=9A=84=E4=BA=BA=E5=B7=A5=E6=99=BA?= =?UTF-8?q?=E8=83=BD=E8=BE=85=E5=8A=A9=E6=94=AF=E6=8C=81=20https://github.?= =?UTF-8?q?com/siyuan-note/siyuan/issues/7566?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- kernel/model/ai.go | 50 ++++++++++++++++++++++++++++++++---------- kernel/model/liandi.go | 4 ++-- kernel/util/openai.go | 2 +- 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/kernel/model/ai.go b/kernel/model/ai.go index 9c5dac1fe..409f8b77a 100644 --- a/kernel/model/ai.go +++ b/kernel/model/ai.go @@ -22,6 +22,7 @@ import ( "github.com/88250/lute/ast" "github.com/88250/lute/parse" + gogpt "github.com/sashabaranov/go-gpt3" "github.com/siyuan-note/siyuan/kernel/treenode" "github.com/siyuan-note/siyuan/kernel/util" ) @@ -47,18 +48,24 @@ func ChatGPTWithAction(ids []string, action string) (ret string) { var cachedContextMsg []string func chatGPT(msg string, cloud bool) (ret string) { - ret, retCtxMsgs := chatGPTContinueWrite(msg, cachedContextMsg, cloud) + ret, retCtxMsgs, err := chatGPTContinueWrite(msg, cachedContextMsg, cloud) + if nil != err { + return + } cachedContextMsg = append(cachedContextMsg, retCtxMsgs...) return } func chatGPTWithAction(msg string, action string, cloud bool) (ret string) { msg = action + ":\n\n" + msg - ret, _ = chatGPTContinueWrite(msg, nil, cloud) + ret, _, err := chatGPTContinueWrite(msg, nil, cloud) + if nil != err { + return + } return } -func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string) { +func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string, err error) { util.PushEndlessProgress("Requesting...") defer util.ClearPushProgress(100) @@ -66,19 +73,19 @@ func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret str contextMsgs = contextMsgs[len(contextMsgs)-7:] } - c := util.NewOpenAIClient() + var gpt GPT + if cloud { + gpt = &CloudGPT{} + } else { + gpt = &OpenAIGPT{c: util.NewOpenAIClient()} + } + buf := &bytes.Buffer{} for i := 0; i < 7; i++ { - var part string - var stop bool - if cloud { - part, stop = CloudChatGPT(msg, contextMsgs) - } else { - part, stop = util.ChatGPT(msg, contextMsgs, c) - } + part, stop, chatErr := gpt.chat(msg, contextMsgs) buf.WriteString(part) - if stop { + if stop || nil != chatErr { break } @@ -138,3 +145,22 @@ func getBlocksContent(ids []string) string { } return buf.String() } + +type GPT interface { + chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) +} + +type OpenAIGPT struct { + c *gogpt.Client +} + +func (gpt *OpenAIGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) { + return util.ChatGPT(msg, contextMsgs, gpt.c) +} + +type CloudGPT struct { +} + +func (gpt *CloudGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) { + return CloudChatGPT(msg, contextMsgs) +} diff --git a/kernel/model/liandi.go b/kernel/model/liandi.go index a7d24fc8e..e5279926e 100644 --- a/kernel/model/liandi.go +++ b/kernel/model/liandi.go @@ -36,7 +36,7 @@ import ( var ErrFailedToConnectCloudServer = errors.New("failed to connect cloud server") -func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool) { +func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool, err error) { if nil == Conf.User { return } @@ -57,7 +57,7 @@ func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool) { requestResult := gulu.Ret.NewResult() request := httpclient.NewCloudRequest30s() - _, err := request. + _, err = request. SetSuccessResult(requestResult). SetCookies(&http.Cookie{Name: "symphony", Value: Conf.User.UserToken}). SetBody(payload). diff --git a/kernel/util/openai.go b/kernel/util/openai.go index b6bb3ce9f..335222879 100644 --- a/kernel/util/openai.go +++ b/kernel/util/openai.go @@ -37,7 +37,7 @@ var ( OpenAIAPIBaseURL = "https://api.openai.com/v1" ) -func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool) { +func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool, err error) { var reqMsgs []gogpt.ChatCompletionMessage for _, ctxMsg := range contextMsgs {