From 1c9e0356cc78423ad4f860077be979b6f3d7d694 Mon Sep 17 00:00:00 2001 From: Liang Ding Date: Wed, 8 Mar 2023 12:12:38 +0800 Subject: [PATCH 1/3] =?UTF-8?q?:art:=20=E6=8E=A5=E5=85=A5=E4=BA=91?= =?UTF-8?q?=E7=AB=AF=E4=BA=BA=E5=B7=A5=E6=99=BA=E8=83=BD=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E5=85=AC=E6=B5=8B=20https://github.com/siyuan-note/siyuan/issu?= =?UTF-8?q?es/7601?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- kernel/model/ai.go | 68 +++++++++++++++++++++++++++++++++++++----- kernel/model/liandi.go | 60 +++++++++++++++++++++++++++++++++++++ kernel/util/openai.go | 49 ++---------------------------- 3 files changed, 123 insertions(+), 54 deletions(-) diff --git a/kernel/model/ai.go b/kernel/model/ai.go index df9b33b02..3afe2c89a 100644 --- a/kernel/model/ai.go +++ b/kernel/model/ai.go @@ -18,6 +18,7 @@ package model import ( "bytes" + "strings" "github.com/88250/lute/ast" "github.com/88250/lute/parse" @@ -25,22 +26,75 @@ import ( "github.com/siyuan-note/siyuan/kernel/util" ) -func ChatGPTWithAction(ids []string, action string) (ret string) { - if !isOpenAIAPIEnabled() { +func ChatGPT(msg string) (ret string) { + cloud := IsSubscriber() + if !cloud && !isOpenAIAPIEnabled() { return } + cloud = false + + return chatGPT(msg, cloud) +} + +func ChatGPTWithAction(ids []string, action string) (ret string) { + cloud := IsSubscriber() + if !cloud && !isOpenAIAPIEnabled() { + return + } + + cloud = false + msg := getBlocksContent(ids) - ret = util.ChatGPTWithAction(msg, action) + ret = chatGPTWithAction(msg, action, cloud) return } -func ChatGPT(msg string) (ret string) { - if !isOpenAIAPIEnabled() { - return +var cachedContextMsg []string + +func chatGPT(msg string, cloud bool) (ret string) { + ret, retCtxMsgs := chatGPTContinueWrite(msg, cachedContextMsg, cloud) + cachedContextMsg = append(cachedContextMsg, retCtxMsgs...) + return +} + +func chatGPTWithAction(msg string, action string, cloud bool) (ret string) { + msg = action + ":\n\n" + msg + ret, _ = chatGPTContinueWrite(msg, nil, cloud) + return +} + +func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string) { + util.PushEndlessProgress("Requesting...") + defer util.ClearPushProgress(100) + + if 7 < len(contextMsgs) { + contextMsgs = contextMsgs[len(contextMsgs)-7:] } - return util.ChatGPT(msg) + c := util.NewOpenAIClient() + buf := &bytes.Buffer{} + for i := 0; i < 7; i++ { + var part string + var stop bool + if cloud { + part, stop = CloudChatGPT(msg, contextMsgs) + } else { + part, stop = util.ChatGPT(msg, contextMsgs, c) + } + buf.WriteString(part) + + if stop { + break + } + + util.PushEndlessProgress("Continue requesting...") + } + + ret = buf.String() + ret = strings.TrimSpace(ret) + retContextMsgs = append(retContextMsgs, msg, ret) + return } func isOpenAIAPIEnabled() bool { diff --git a/kernel/model/liandi.go b/kernel/model/liandi.go index 1b471e9b3..a7d24fc8e 100644 --- a/kernel/model/liandi.go +++ b/kernel/model/liandi.go @@ -36,6 +36,66 @@ import ( var ErrFailedToConnectCloudServer = errors.New("failed to connect cloud server") +func CloudChatGPT(msg string, contextMsgs []string) (ret string, stop bool) { + if nil == Conf.User { + return + } + + payload := map[string]interface{}{} + var messages []map[string]interface{} + for _, contextMsg := range contextMsgs { + messages = append(messages, map[string]interface{}{ + "role": "user", + "content": contextMsg, + }) + } + messages = append(messages, map[string]interface{}{ + "role": "user", + "content": msg, + }) + payload["messages"] = messages + + requestResult := gulu.Ret.NewResult() + request := httpclient.NewCloudRequest30s() + _, err := request. + SetSuccessResult(requestResult). + SetCookies(&http.Cookie{Name: "symphony", Value: Conf.User.UserToken}). + SetBody(payload). + Post(util.AliyunServer + "/apis/siyuan/ai/chatGPT") + if nil != err { + logging.LogErrorf("chat gpt failed: %s", err) + err = ErrFailedToConnectCloudServer + return + } + if 0 != requestResult.Code { + err = errors.New(requestResult.Msg) + stop = true + return + } + + data := requestResult.Data.(map[string]interface{}) + choices := data["choices"].([]interface{}) + if 1 > len(choices) { + stop = true + return + } + choice := choices[0].(map[string]interface{}) + message := choice["message"].(map[string]interface{}) + ret = message["content"].(string) + + if nil != choice["finish_reason"] { + finishReason := choice["finish_reason"].(string) + if "length" == finishReason { + stop = false + } else { + stop = true + } + } else { + stop = true + } + return +} + func StartFreeTrial() (err error) { if nil == Conf.User { return errors.New(Conf.Language(31)) diff --git a/kernel/util/openai.go b/kernel/util/openai.go index 6a13690b6..247cb8e68 100644 --- a/kernel/util/openai.go +++ b/kernel/util/openai.go @@ -17,7 +17,6 @@ package util import ( - "bytes" "context" "net/http" "net/url" @@ -37,52 +36,8 @@ var ( OpenAIAPIMaxTokens = 0 ) -var cachedContextMsg []string - -func ChatGPT(msg string) (ret string) { - ret, retCtxMsgs := ChatGPTContinueWrite(msg, cachedContextMsg) - cachedContextMsg = append(cachedContextMsg, retCtxMsgs...) - return -} - -func ChatGPTWithAction(msg string, action string) (ret string) { - msg = action + ":\n\n" + msg - ret, _ = ChatGPTContinueWrite(msg, nil) - return -} - -func ChatGPTContinueWrite(msg string, contextMsgs []string) (ret string, retContextMsgs []string) { - if "" == OpenAIAPIKey { - return - } - - PushEndlessProgress("Requesting...") - defer ClearPushProgress(100) - - c := newOpenAIClient() - buf := &bytes.Buffer{} - for i := 0; i < 7; i++ { - part, stop := chatGPT(msg, contextMsgs, c) - buf.WriteString(part) - - if stop { - break - } - - PushEndlessProgress("Continue requesting...") - } - - ret = buf.String() - ret = strings.TrimSpace(ret) - retContextMsgs = append(retContextMsgs, msg, ret) - return -} - -func chatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool) { +func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool) { var reqMsgs []gogpt.ChatCompletionMessage - if 7 < len(contextMsgs) { - contextMsgs = contextMsgs[len(contextMsgs)-7:] - } for _, ctxMsg := range contextMsgs { reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{ @@ -129,7 +84,7 @@ func chatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, sto return } -func newOpenAIClient() *gogpt.Client { +func NewOpenAIClient() *gogpt.Client { config := gogpt.DefaultConfig(OpenAIAPIKey) if "" != OpenAIAPIProxy { proxyUrl, err := url.Parse(OpenAIAPIProxy) From a8129de5f5dce8fe6fe460889b88f1d9d17b720a Mon Sep 17 00:00:00 2001 From: Liang Ding Date: Wed, 8 Mar 2023 20:00:17 +0800 Subject: [PATCH 2/3] =?UTF-8?q?:art:=20OpenAI=20API=20=E6=8E=A5=E5=85=A5?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E7=8E=AF=E5=A2=83=E5=8F=98=E9=87=8F=20`SIYUA?= =?UTF-8?q?N=5FOPENAI=5FAPI=5FBASE=5FURL`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- kernel/model/ai.go | 14 ++++---------- kernel/util/openai.go | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/kernel/model/ai.go b/kernel/model/ai.go index 3afe2c89a..9c5dac1fe 100644 --- a/kernel/model/ai.go +++ b/kernel/model/ai.go @@ -27,26 +27,20 @@ import ( ) func ChatGPT(msg string) (ret string) { - cloud := IsSubscriber() - if !cloud && !isOpenAIAPIEnabled() { + if !isOpenAIAPIEnabled() { return } - cloud = false - - return chatGPT(msg, cloud) + return chatGPT(msg, false) } func ChatGPTWithAction(ids []string, action string) (ret string) { - cloud := IsSubscriber() - if !cloud && !isOpenAIAPIEnabled() { + if !isOpenAIAPIEnabled() { return } - cloud = false - msg := getBlocksContent(ids) - ret = chatGPTWithAction(msg, action, cloud) + ret = chatGPTWithAction(msg, action, false) return } diff --git a/kernel/util/openai.go b/kernel/util/openai.go index 247cb8e68..b6bb3ce9f 100644 --- a/kernel/util/openai.go +++ b/kernel/util/openai.go @@ -34,6 +34,7 @@ var ( OpenAIAPITimeout = 30 * time.Second OpenAIAPIProxy = "" OpenAIAPIMaxTokens = 0 + OpenAIAPIBaseURL = "https://api.openai.com/v1" ) func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool) { @@ -94,6 +95,8 @@ func NewOpenAIClient() *gogpt.Client { config.HTTPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}} } } + + config.BaseURL = OpenAIAPIBaseURL return gogpt.NewClientWithConfig(config) } @@ -124,5 +127,15 @@ func initOpenAI() { } } - logging.LogInfof("OpenAI API enabled [maxTokens=%d, timeout=%ds, proxy=%s]", OpenAIAPIMaxTokens, int(OpenAIAPITimeout.Seconds()), OpenAIAPIProxy) + baseURL := os.Getenv("SIYUAN_OPENAI_API_BASE_URL") + if "" != baseURL { + OpenAIAPIBaseURL = baseURL + } + + logging.LogInfof("OpenAI API enabled\n"+ + " baseURL=%s\n"+ + " timeout=%ds\n"+ + " proxy=%s\n"+ + " maxTokens=%d", + OpenAIAPIBaseURL, int(OpenAIAPITimeout.Seconds()), OpenAIAPIProxy, OpenAIAPIMaxTokens) } From 63dcb41b353e360fff27bbad1cf3163da203c187 Mon Sep 17 00:00:00 2001 From: Liang Ding Date: Wed, 8 Mar 2023 20:00:46 +0800 Subject: [PATCH 3/3] =?UTF-8?q?:pencil2:=20=E4=BF=AE=E5=A4=8D=E8=A5=BF?= =?UTF-8?q?=E7=8F=AD=E7=89=99=E8=AF=AD=E5=92=8C=E6=B3=95=E8=AF=AD=E5=A4=9A?= =?UTF-8?q?=E8=AF=AD=E8=A8=80=E6=96=87=E4=BB=B6=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/appearance/langs/es_ES.json | 2 +- app/appearance/langs/fr_FR.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/app/appearance/langs/es_ES.json b/app/appearance/langs/es_ES.json index cf8531e53..505b7beef 100644 --- a/app/appearance/langs/es_ES.json +++ b/app/appearance/langs/es_ES.json @@ -909,7 +909,7 @@ "task.history.database.index.commit": "Ejecutar la confirmación del índice de la base de datos del historial", "task.database.index.embedBlock": "Ejecutar bloque de incrustación de índice de base de datos", "task.reload.ui": "IU de recarga de tareas", - "task.upgrade.userGuide": "Ejecutar la guía de usuario de actualización", + "task.upgrade.userGuide": "Ejecutar la guía de usuario de actualización" }, "_trayMenu": { "showWindow": "Mostrar ventana", diff --git a/app/appearance/langs/fr_FR.json b/app/appearance/langs/fr_FR.json index 5ddd47df3..ed9611b98 100644 --- a/app/appearance/langs/fr_FR.json +++ b/app/appearance/langs/fr_FR.json @@ -909,7 +909,7 @@ "task.history.database.index.commit": "Effectuer la validation de l'index de la base de données d'historique", "task.database.index.embedBlock": "Exécuter le bloc d'intégration d'index de base de données", "task.reload.ui": "Interface utilisateur de rechargement de tâche", - "task.upgrade.userGuide": "Mise à niveau de la tâche de guide utilisateur", + "task.upgrade.userGuide": "Mise à niveau de la tâche de guide utilisateur" }, "_trayMenu": { "showWindow": "Afficher la fenêtre principale",