🎨 AI 加入上下文信息

This commit is contained in:
Liang Ding 2023-03-05 00:20:07 +08:00
parent 3c9e80b411
commit 5eae66a35e
No known key found for this signature in database
GPG key ID: 136F30F901A2231D

View file

@ -40,7 +40,15 @@ var (
OpenAIAPIMaxTokens = 0 OpenAIAPIMaxTokens = 0
) )
var cachedContextMsg []string
func ChatGPT(msg string) (ret string) { func ChatGPT(msg string) (ret string) {
ret, retCtxMsgs := ChatGPTContinueWrite(msg, cachedContextMsg)
cachedContextMsg = append(cachedContextMsg, retCtxMsgs...)
return
}
func ChatGPTContinueWrite(msg string, contextMsgs []string) (ret string, retContextMsgs []string) {
if "" == OpenAIAPIKey { if "" == OpenAIAPIKey {
return return
} }
@ -48,30 +56,31 @@ func ChatGPT(msg string) (ret string) {
PushEndlessProgress("Requesting...") PushEndlessProgress("Requesting...")
defer ClearPushProgress(100) defer ClearPushProgress(100)
config := gogpt.DefaultConfig(OpenAIAPIKey) c := newOpenAIClient()
if "" != OpenAIAPIProxy {
proxyUrl, err := url.Parse(OpenAIAPIProxy) var reqMsgs []gogpt.ChatCompletionMessage
if nil != err { if 7 < len(contextMsgs) {
logging.LogErrorf("OpenAI API proxy error: %v", err) contextMsgs = contextMsgs[len(contextMsgs)-7:]
} else {
config.HTTPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}}
}
} }
c := gogpt.NewClientWithConfig(config) for _, ctxMsg := range contextMsgs {
ctx, cancel := context.WithTimeout(context.Background(), OpenAIAPITimeout) reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{
defer cancel() Role: "user",
Content: ctxMsg,
})
}
reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{
Role: "user",
Content: msg,
})
req := gogpt.ChatCompletionRequest{ req := gogpt.ChatCompletionRequest{
Model: gogpt.GPT3Dot5Turbo, Model: gogpt.GPT3Dot5Turbo,
MaxTokens: OpenAIAPIMaxTokens, MaxTokens: OpenAIAPIMaxTokens,
Messages: []gogpt.ChatCompletionMessage{ Messages: reqMsgs,
{
Role: "user",
Content: msg,
},
},
} }
ctx, cancel := context.WithTimeout(context.Background(), OpenAIAPITimeout)
defer cancel()
stream, err := c.CreateChatCompletionStream(ctx, req) stream, err := c.CreateChatCompletionStream(ctx, req)
if nil != err { if nil != err {
logging.LogErrorf("create chat completion stream failed: %s", err) logging.LogErrorf("create chat completion stream failed: %s", err)
@ -100,9 +109,23 @@ func ChatGPT(msg string) (ret string) {
ret = buf.String() ret = buf.String()
ret = strings.TrimSpace(ret) ret = strings.TrimSpace(ret)
retContextMsgs = append(retContextMsgs, msg, ret)
return return
} }
func newOpenAIClient() *gogpt.Client {
config := gogpt.DefaultConfig(OpenAIAPIKey)
if "" != OpenAIAPIProxy {
proxyUrl, err := url.Parse(OpenAIAPIProxy)
if nil != err {
logging.LogErrorf("OpenAI API proxy failed: %v", err)
} else {
config.HTTPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}}
}
}
return gogpt.NewClientWithConfig(config)
}
func initOpenAI() { func initOpenAI() {
OpenAIAPIKey = os.Getenv("SIYUAN_OPENAI_API_KEY") OpenAIAPIKey = os.Getenv("SIYUAN_OPENAI_API_KEY")
if "" == OpenAIAPIKey { if "" == OpenAIAPIKey {