diff --git a/kernel/api/ai.go b/kernel/api/ai.go index 13717e233..b0a49824c 100644 --- a/kernel/api/ai.go +++ b/kernel/api/ai.go @@ -89,3 +89,37 @@ func chatGPTSummary(c *gin.Context) { } ret.Data = model.ChatGPTSummary(ids) } + +func chatGPTBrainStorm(c *gin.Context) { + ret := gulu.Ret.NewResult() + defer c.JSON(http.StatusOK, ret) + + arg, ok := util.JsonArg(c, ret) + if !ok { + return + } + + idsArg := arg["ids"].([]interface{}) + var ids []string + for _, id := range idsArg { + ids = append(ids, id.(string)) + } + ret.Data = model.ChatGPTBrainStorm(ids) +} + +func chatGPTFixGrammarSpell(c *gin.Context) { + ret := gulu.Ret.NewResult() + defer c.JSON(http.StatusOK, ret) + + arg, ok := util.JsonArg(c, ret) + if !ok { + return + } + + idsArg := arg["ids"].([]interface{}) + var ids []string + for _, id := range idsArg { + ids = append(ids, id.(string)) + } + ret.Data = model.ChatGPTFixGrammarSpell(ids) +} diff --git a/kernel/api/router.go b/kernel/api/router.go index de415a515..9c10c5883 100644 --- a/kernel/api/router.go +++ b/kernel/api/router.go @@ -332,4 +332,6 @@ func ServeAPI(ginServer *gin.Engine) { ginServer.Handle("POST", "/api/ai/chatGPTContinueWriteBlocks", model.CheckAuth, chatGPTContinueWriteBlocks) ginServer.Handle("POST", "/api/ai/chatGPTTranslate", model.CheckAuth, chatGPTTranslate) ginServer.Handle("POST", "/api/ai/chatGPTSummary", model.CheckAuth, chatGPTSummary) + ginServer.Handle("POST", "/api/ai/chatGPTBrainStorm", model.CheckAuth, chatGPTBrainStorm) + ginServer.Handle("POST", "/api/ai/chatGPTFixGrammarSpell", model.CheckAuth, chatGPTFixGrammarSpell) } diff --git a/kernel/model/ai.go b/kernel/model/ai.go index bc48dc5a9..dabe796f7 100644 --- a/kernel/model/ai.go +++ b/kernel/model/ai.go @@ -25,6 +25,26 @@ import ( "github.com/siyuan-note/siyuan/kernel/util" ) +func ChatGPTFixGrammarSpell(ids []string) (ret string) { + if !isOpenAIAPIEnabled() { + return + } + + msg := getBlocksContent(ids) + ret = util.ChatGPTFixGrammarSpell(msg, Conf.Lang) + return +} + +func ChatGPTBrainStorm(ids []string) (ret string) { + if !isOpenAIAPIEnabled() { + return + } + + msg := getBlocksContent(ids) + ret = util.ChatGPTBrainStorm(msg, Conf.Lang) + return +} + func ChatGPTSummary(ids []string) (ret string) { if !isOpenAIAPIEnabled() { return diff --git a/kernel/util/openai.go b/kernel/util/openai.go index ca98d6528..17775ea11 100644 --- a/kernel/util/openai.go +++ b/kernel/util/openai.go @@ -57,6 +57,18 @@ func ChatGPTSummary(msg string, lang string) (ret string) { return } +func ChatGPTBrainStorm(msg string, lang string) (ret string) { + msg = "Brainstorm ideas as follows, the result is in {" + lang + "}:\n" + msg + ret, _ = ChatGPTContinueWrite(msg, nil) + return +} + +func ChatGPTFixGrammarSpell(msg string, lang string) (ret string) { + msg = "Fix grammar and spelling as follows, the result is in {" + lang + "}:\n" + msg + ret, _ = ChatGPTContinueWrite(msg, nil) + return +} + func ChatGPTContinueWrite(msg string, contextMsgs []string) (ret string, retContextMsgs []string) { if "" == OpenAIAPIKey { return