diff --git a/kernel/model/ai.go b/kernel/model/ai.go index 9c5dac1fe..409f8b77a 100644 --- a/kernel/model/ai.go +++ b/kernel/model/ai.go @@ -22,6 +22,7 @@ import ( "github.com/88250/lute/ast" "github.com/88250/lute/parse" + gogpt "github.com/sashabaranov/go-gpt3" "github.com/siyuan-note/siyuan/kernel/treenode" "github.com/siyuan-note/siyuan/kernel/util" ) @@ -47,18 +48,24 @@ func ChatGPTWithAction(ids []string, action string) (ret string) { var cachedContextMsg []string func chatGPT(msg string, cloud bool) (ret string) { - ret, retCtxMsgs := chatGPTContinueWrite(msg, cachedContextMsg, cloud) + ret, retCtxMsgs, err := chatGPTContinueWrite(msg, cachedContextMsg, cloud) + if nil != err { + return + } cachedContextMsg = append(cachedContextMsg, retCtxMsgs...) return } func chatGPTWithAction(msg string, action string, cloud bool) (ret string) { msg = action + ":\n\n" + msg - ret, _ = chatGPTContinueWrite(msg, nil, cloud) + ret, _, err := chatGPTContinueWrite(msg, nil, cloud) + if nil != err { + return + } return } -func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string) { +func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string, err error) { util.PushEndlessProgress("Requesting...") defer util.ClearPushProgress(100) @@ -66,19 +73,19 @@ func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret str contextMsgs = contextMsgs[len(contextMsgs)-7:] } - c := util.NewOpenAIClient() + var gpt GPT + if cloud { + gpt = &CloudGPT{} + } else { + gpt = &OpenAIGPT{c: util.NewOpenAIClient()} + } + buf := &bytes.Buffer{} for i := 0; i < 7; i++ { - var part string - var stop bool - if cloud { - part, stop = CloudChatGPT(msg, contextMsgs) - } else { - part, stop = util.ChatGPT(msg, contextMsgs, c) - } + part, stop, chatErr := gpt.chat(msg, contextMsgs) buf.WriteString(part) - if stop { + if stop || nil != chatErr { break } @@ -138,3 +145,22 @@ func getBlocksContent(ids []string) string { } return buf.String() } + +type GPT interface { + chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) +} + +type OpenAIGPT struct { + c *gogpt.Client +} + +func (gpt *OpenAIGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) { + return util.ChatGPT(msg, contextMsgs, gpt.c) +} + +type CloudGPT struct { +} + +func (gpt *CloudGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) { + return CloudChatGPT(msg, contextMsgs) +} diff --git a/kernel/model/liandi.go b/kernel/model/liandi.go index a7d24fc8e..e5279926e 100644 --- a/kernel/model/liandi.go +++ b/kernel/model/liandi.go @@ -36,7 +36,7 @@ import ( var ErrFailedToConnectCloudServer = errors.New("failed to connect cloud server") -func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool) { +func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool, err error) { if nil == Conf.User { return } @@ -57,7 +57,7 @@ func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool) { requestResult := gulu.Ret.NewResult() request := httpclient.NewCloudRequest30s() - _, err := request. + _, err = request. SetSuccessResult(requestResult). SetCookies(&http.Cookie{Name: "symphony", Value: Conf.User.UserToken}). SetBody(payload). diff --git a/kernel/util/openai.go b/kernel/util/openai.go index b6bb3ce9f..335222879 100644 --- a/kernel/util/openai.go +++ b/kernel/util/openai.go @@ -37,7 +37,7 @@ var ( OpenAIAPIBaseURL = "https://api.openai.com/v1" ) -func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool) { +func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool, err error) { var reqMsgs []gogpt.ChatCompletionMessage for _, ctxMsg := range contextMsgs {