From 1c9e0356cc78423ad4f860077be979b6f3d7d694 Mon Sep 17 00:00:00 2001 From: Liang Ding Date: Wed, 8 Mar 2023 12:12:38 +0800 Subject: [PATCH] =?UTF-8?q?:art:=20=E6=8E=A5=E5=85=A5=E4=BA=91=E7=AB=AF?= =?UTF-8?q?=E4=BA=BA=E5=B7=A5=E6=99=BA=E8=83=BD=E6=8E=A5=E5=8F=A3=E5=85=AC?= =?UTF-8?q?=E6=B5=8B=20https://github.com/siyuan-note/siyuan/issues/7601?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- kernel/model/ai.go | 68 +++++++++++++++++++++++++++++++++++++----- kernel/model/liandi.go | 60 +++++++++++++++++++++++++++++++++++++ kernel/util/openai.go | 49 ++---------------------------- 3 files changed, 123 insertions(+), 54 deletions(-) diff --git a/kernel/model/ai.go b/kernel/model/ai.go index df9b33b02..3afe2c89a 100644 --- a/kernel/model/ai.go +++ b/kernel/model/ai.go @@ -18,6 +18,7 @@ package model import ( "bytes" + "strings" "github.com/88250/lute/ast" "github.com/88250/lute/parse" @@ -25,22 +26,75 @@ import ( "github.com/siyuan-note/siyuan/kernel/util" ) -func ChatGPTWithAction(ids []string, action string) (ret string) { - if !isOpenAIAPIEnabled() { +func ChatGPT(msg string) (ret string) { + cloud := IsSubscriber() + if !cloud && !isOpenAIAPIEnabled() { return } + cloud = false + + return chatGPT(msg, cloud) +} + +func ChatGPTWithAction(ids []string, action string) (ret string) { + cloud := IsSubscriber() + if !cloud && !isOpenAIAPIEnabled() { + return + } + + cloud = false + msg := getBlocksContent(ids) - ret = util.ChatGPTWithAction(msg, action) + ret = chatGPTWithAction(msg, action, cloud) return } -func ChatGPT(msg string) (ret string) { - if !isOpenAIAPIEnabled() { - return +var cachedContextMsg []string + +func chatGPT(msg string, cloud bool) (ret string) { + ret, retCtxMsgs := chatGPTContinueWrite(msg, cachedContextMsg, cloud) + cachedContextMsg = append(cachedContextMsg, retCtxMsgs...) + return +} + +func chatGPTWithAction(msg string, action string, cloud bool) (ret string) { + msg = action + ":\n\n" + msg + ret, _ = chatGPTContinueWrite(msg, nil, cloud) + return +} + +func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string) { + util.PushEndlessProgress("Requesting...") + defer util.ClearPushProgress(100) + + if 7 < len(contextMsgs) { + contextMsgs = contextMsgs[len(contextMsgs)-7:] } - return util.ChatGPT(msg) + c := util.NewOpenAIClient() + buf := &bytes.Buffer{} + for i := 0; i < 7; i++ { + var part string + var stop bool + if cloud { + part, stop = CloudChatGPT(msg, contextMsgs) + } else { + part, stop = util.ChatGPT(msg, contextMsgs, c) + } + buf.WriteString(part) + + if stop { + break + } + + util.PushEndlessProgress("Continue requesting...") + } + + ret = buf.String() + ret = strings.TrimSpace(ret) + retContextMsgs = append(retContextMsgs, msg, ret) + return } func isOpenAIAPIEnabled() bool { diff --git a/kernel/model/liandi.go b/kernel/model/liandi.go index 1b471e9b3..a7d24fc8e 100644 --- a/kernel/model/liandi.go +++ b/kernel/model/liandi.go @@ -36,6 +36,66 @@ import ( var ErrFailedToConnectCloudServer = errors.New("failed to connect cloud server") +func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool) { + if nil == Conf.User { + return + } + + payload := map[string]interface{}{} + var messages []map[string]interface{} + for _, contextMsg := range contextMsgs { + messages = append(messages, map[string]interface{}{ + "role": "user", + "content": contextMsg, + }) + } + messages = append(messages, map[string]interface{}{ + "role": "user", + "content": msg, + }) + payload["messages"] = messages + + requestResult := gulu.Ret.NewResult() + request := httpclient.NewCloudRequest30s() + _, err := request. + SetSuccessResult(requestResult). + SetCookies(&http.Cookie{Name: "symphony", Value: Conf.User.UserToken}). + SetBody(payload). + Post(util.AliyunServer + "/apis/siyuan/ai/chatGPT") + if nil != err { + logging.LogErrorf("chat gpt failed: %s", err) + err = ErrFailedToConnectCloudServer + return + } + if 0 != requestResult.Code { + err = errors.New(requestResult.Msg) + stop = true + return + } + + data := requestResult.Data.(map[string]interface{}) + choices := data["choices"].([]interface{}) + if 1 > len(choices) { + stop = true + return + } + choice := choices[0].(map[string]interface{}) + message := choice["message"].(map[string]interface{}) + ret = message["content"].(string) + + if nil != choice["finish_reason"] { + finishReason := choice["finish_reason"].(string) + if "length" == finishReason { + stop = false + } else { + stop = true + } + } else { + stop = true + } + return +} + func StartFreeTrial() (err error) { if nil == Conf.User { return errors.New(Conf.Language(31)) diff --git a/kernel/util/openai.go b/kernel/util/openai.go index 6a13690b6..247cb8e68 100644 --- a/kernel/util/openai.go +++ b/kernel/util/openai.go @@ -17,7 +17,6 @@ package util import ( - "bytes" "context" "net/http" "net/url" @@ -37,52 +36,8 @@ var ( OpenAIAPIMaxTokens = 0 ) -var cachedContextMsg []string - -func ChatGPT(msg string) (ret string) { - ret, retCtxMsgs := ChatGPTContinueWrite(msg, cachedContextMsg) - cachedContextMsg = append(cachedContextMsg, retCtxMsgs...) - return -} - -func ChatGPTWithAction(msg string, action string) (ret string) { - msg = action + ":\n\n" + msg - ret, _ = ChatGPTContinueWrite(msg, nil) - return -} - -func ChatGPTContinueWrite(msg string, contextMsgs []string) (ret string, retContextMsgs []string) { - if "" == OpenAIAPIKey { - return - } - - PushEndlessProgress("Requesting...") - defer ClearPushProgress(100) - - c := newOpenAIClient() - buf := &bytes.Buffer{} - for i := 0; i < 7; i++ { - part, stop := chatGPT(msg, contextMsgs, c) - buf.WriteString(part) - - if stop { - break - } - - PushEndlessProgress("Continue requesting...") - } - - ret = buf.String() - ret = strings.TrimSpace(ret) - retContextMsgs = append(retContextMsgs, msg, ret) - return -} - -func chatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool) { +func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool) { var reqMsgs []gogpt.ChatCompletionMessage - if 7 < len(contextMsgs) { - contextMsgs = contextMsgs[len(contextMsgs)-7:] - } for _, ctxMsg := range contextMsgs { reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{ @@ -129,7 +84,7 @@ func chatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, sto return } -func newOpenAIClient() *gogpt.Client { +func NewOpenAIClient() *gogpt.Client { config := gogpt.DefaultConfig(OpenAIAPIKey) if "" != OpenAIAPIProxy { proxyUrl, err := url.Parse(OpenAIAPIProxy)