diff --git a/kernel/util/openai.go b/kernel/util/openai.go index 7478f8138..ada939c33 100644 --- a/kernel/util/openai.go +++ b/kernel/util/openai.go @@ -40,7 +40,15 @@ var ( OpenAIAPIMaxTokens = 0 ) +var cachedContextMsg []string + func ChatGPT(msg string) (ret string) { + ret, retCtxMsgs := ChatGPTContinueWrite(msg, cachedContextMsg) + cachedContextMsg = append(cachedContextMsg, retCtxMsgs...) + return +} + +func ChatGPTContinueWrite(msg string, contextMsgs []string) (ret string, retContextMsgs []string) { if "" == OpenAIAPIKey { return } @@ -48,30 +56,31 @@ func ChatGPT(msg string) (ret string) { PushEndlessProgress("Requesting...") defer ClearPushProgress(100) - config := gogpt.DefaultConfig(OpenAIAPIKey) - if "" != OpenAIAPIProxy { - proxyUrl, err := url.Parse(OpenAIAPIProxy) - if nil != err { - logging.LogErrorf("OpenAI API proxy error: %v", err) - } else { - config.HTTPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}} - } + c := newOpenAIClient() + + var reqMsgs []gogpt.ChatCompletionMessage + if 7 < len(contextMsgs) { + contextMsgs = contextMsgs[len(contextMsgs)-7:] } - c := gogpt.NewClientWithConfig(config) - ctx, cancel := context.WithTimeout(context.Background(), OpenAIAPITimeout) - defer cancel() + for _, ctxMsg := range contextMsgs { + reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{ + Role: "user", + Content: ctxMsg, + }) + } + reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{ + Role: "user", + Content: msg, + }) + req := gogpt.ChatCompletionRequest{ Model: gogpt.GPT3Dot5Turbo, MaxTokens: OpenAIAPIMaxTokens, - Messages: []gogpt.ChatCompletionMessage{ - { - Role: "user", - Content: msg, - }, - }, + Messages: reqMsgs, } - + ctx, cancel := context.WithTimeout(context.Background(), OpenAIAPITimeout) + defer cancel() stream, err := c.CreateChatCompletionStream(ctx, req) if nil != err { logging.LogErrorf("create chat completion stream failed: %s", err) @@ -100,9 +109,23 @@ func ChatGPT(msg string) (ret string) { ret = buf.String() ret = strings.TrimSpace(ret) + retContextMsgs = append(retContextMsgs, msg, ret) return } +func newOpenAIClient() *gogpt.Client { + config := gogpt.DefaultConfig(OpenAIAPIKey) + if "" != OpenAIAPIProxy { + proxyUrl, err := url.Parse(OpenAIAPIProxy) + if nil != err { + logging.LogErrorf("OpenAI API proxy failed: %v", err) + } else { + config.HTTPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}} + } + } + return gogpt.NewClientWithConfig(config) +} + func initOpenAI() { OpenAIAPIKey = os.Getenv("SIYUAN_OPENAI_API_KEY") if "" == OpenAIAPIKey {