diff --git a/kernel/api/ai.go b/kernel/api/ai.go index b0a49824c..02dece192 100644 --- a/kernel/api/ai.go +++ b/kernel/api/ai.go @@ -38,7 +38,7 @@ func chatGPT(c *gin.Context) { ret.Data = model.ChatGPT(msg) } -func chatGPTContinueWriteBlocks(c *gin.Context) { +func chatGPTWithAction(c *gin.Context) { ret := gulu.Ret.NewResult() defer c.JSON(http.StatusOK, ret) @@ -52,74 +52,6 @@ func chatGPTContinueWriteBlocks(c *gin.Context) { for _, id := range idsArg { ids = append(ids, id.(string)) } - ret.Data = model.ChatGPTContinueWriteBlocks(ids) -} - -func chatGPTTranslate(c *gin.Context) { - ret := gulu.Ret.NewResult() - defer c.JSON(http.StatusOK, ret) - - arg, ok := util.JsonArg(c, ret) - if !ok { - return - } - - idsArg := arg["ids"].([]interface{}) - var ids []string - for _, id := range idsArg { - ids = append(ids, id.(string)) - } - lang := arg["lang"].(string) - ret.Data = model.ChatGPTTranslate(ids, lang) -} - -func chatGPTSummary(c *gin.Context) { - ret := gulu.Ret.NewResult() - defer c.JSON(http.StatusOK, ret) - - arg, ok := util.JsonArg(c, ret) - if !ok { - return - } - - idsArg := arg["ids"].([]interface{}) - var ids []string - for _, id := range idsArg { - ids = append(ids, id.(string)) - } - ret.Data = model.ChatGPTSummary(ids) -} - -func chatGPTBrainStorm(c *gin.Context) { - ret := gulu.Ret.NewResult() - defer c.JSON(http.StatusOK, ret) - - arg, ok := util.JsonArg(c, ret) - if !ok { - return - } - - idsArg := arg["ids"].([]interface{}) - var ids []string - for _, id := range idsArg { - ids = append(ids, id.(string)) - } - ret.Data = model.ChatGPTBrainStorm(ids) -} - -func chatGPTFixGrammarSpell(c *gin.Context) { - ret := gulu.Ret.NewResult() - defer c.JSON(http.StatusOK, ret) - - arg, ok := util.JsonArg(c, ret) - if !ok { - return - } - - idsArg := arg["ids"].([]interface{}) - var ids []string - for _, id := range idsArg { - ids = append(ids, id.(string)) - } - ret.Data = model.ChatGPTFixGrammarSpell(ids) + action := arg["action"].(string) + ret.Data = model.ChatGPTWithAction(ids, action) } diff --git a/kernel/api/router.go b/kernel/api/router.go index 9c10c5883..8846ca349 100644 --- a/kernel/api/router.go +++ b/kernel/api/router.go @@ -329,9 +329,5 @@ func ServeAPI(ginServer *gin.Engine) { ginServer.Handle("POST", "/api/av/renderAttributeView", model.CheckAuth, renderAttributeView) ginServer.Handle("POST", "/api/ai/chatGPT", model.CheckAuth, chatGPT) - ginServer.Handle("POST", "/api/ai/chatGPTContinueWriteBlocks", model.CheckAuth, chatGPTContinueWriteBlocks) - ginServer.Handle("POST", "/api/ai/chatGPTTranslate", model.CheckAuth, chatGPTTranslate) - ginServer.Handle("POST", "/api/ai/chatGPTSummary", model.CheckAuth, chatGPTSummary) - ginServer.Handle("POST", "/api/ai/chatGPTBrainStorm", model.CheckAuth, chatGPTBrainStorm) - ginServer.Handle("POST", "/api/ai/chatGPTFixGrammarSpell", model.CheckAuth, chatGPTFixGrammarSpell) + ginServer.Handle("POST", "/api/ai/chatGPTWithAction", model.CheckAuth, chatGPTWithAction) } diff --git a/kernel/model/ai.go b/kernel/model/ai.go index dabe796f7..a2f759263 100644 --- a/kernel/model/ai.go +++ b/kernel/model/ai.go @@ -25,53 +25,13 @@ import ( "github.com/siyuan-note/siyuan/kernel/util" ) -func ChatGPTFixGrammarSpell(ids []string) (ret string) { +func ChatGPTWithAction(ids []string, action string) (ret string) { if !isOpenAIAPIEnabled() { return } msg := getBlocksContent(ids) - ret = util.ChatGPTFixGrammarSpell(msg, Conf.Lang) - return -} - -func ChatGPTBrainStorm(ids []string) (ret string) { - if !isOpenAIAPIEnabled() { - return - } - - msg := getBlocksContent(ids) - ret = util.ChatGPTBrainStorm(msg, Conf.Lang) - return -} - -func ChatGPTSummary(ids []string) (ret string) { - if !isOpenAIAPIEnabled() { - return - } - - msg := getBlocksContent(ids) - ret = util.ChatGPTSummary(msg, Conf.Lang) - return -} - -func ChatGPTTranslate(ids []string, lang string) (ret string) { - if !isOpenAIAPIEnabled() { - return - } - - msg := getBlocksContent(ids) - ret = util.ChatGPTTranslate(msg, lang) - return -} - -func ChatGPTContinueWriteBlocks(ids []string) (ret string) { - if !isOpenAIAPIEnabled() { - return - } - - msg := getBlocksContent(ids) - ret, _ = util.ChatGPTContinueWrite(msg, nil) + ret = util.ChatGPTWithAction(msg, action, Conf.Lang) return } diff --git a/kernel/util/openai.go b/kernel/util/openai.go index 17775ea11..2ee9bddfb 100644 --- a/kernel/util/openai.go +++ b/kernel/util/openai.go @@ -45,26 +45,11 @@ func ChatGPT(msg string) (ret string) { return } -func ChatGPTTranslate(msg string, lang string) (ret string) { - msg = "Translate to " + lang + ":\n" + msg - ret, _ = ChatGPTContinueWrite(msg, nil) - return -} - -func ChatGPTSummary(msg string, lang string) (ret string) { - msg = "Summarized as follows, the result is in {" + lang + "}:\n" + msg - ret, _ = ChatGPTContinueWrite(msg, nil) - return -} - -func ChatGPTBrainStorm(msg string, lang string) (ret string) { - msg = "Brainstorm ideas as follows, the result is in {" + lang + "}:\n" + msg - ret, _ = ChatGPTContinueWrite(msg, nil) - return -} - -func ChatGPTFixGrammarSpell(msg string, lang string) (ret string) { - msg = "Fix grammar and spelling as follows, the result is in {" + lang + "}:\n" + msg +func ChatGPTWithAction(msg string, action string, lang string) (ret string) { + prompt := "{action} as follows, the result is in {lang}:\n" + prompt = strings.Replace(prompt, "{action}", action, -1) + prompt = strings.Replace(prompt, "{lang}", lang, -1) + msg = prompt + msg ret, _ = ChatGPTContinueWrite(msg, nil) return }