🎨 AI 加入上下文信息

This commit is contained in:
Liang Ding 2023-03-05 00:20:07 +08:00
parent 3c9e80b411
commit 5eae66a35e
No known key found for this signature in database
GPG key ID: 136F30F901A2231D

View file

@ -40,7 +40,15 @@ var (
OpenAIAPIMaxTokens = 0
)
var cachedContextMsg []string
func ChatGPT(msg string) (ret string) {
ret, retCtxMsgs := ChatGPTContinueWrite(msg, cachedContextMsg)
cachedContextMsg = append(cachedContextMsg, retCtxMsgs...)
return
}
func ChatGPTContinueWrite(msg string, contextMsgs []string) (ret string, retContextMsgs []string) {
if "" == OpenAIAPIKey {
return
}
@ -48,30 +56,31 @@ func ChatGPT(msg string) (ret string) {
PushEndlessProgress("Requesting...")
defer ClearPushProgress(100)
config := gogpt.DefaultConfig(OpenAIAPIKey)
if "" != OpenAIAPIProxy {
proxyUrl, err := url.Parse(OpenAIAPIProxy)
if nil != err {
logging.LogErrorf("OpenAI API proxy error: %v", err)
} else {
config.HTTPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}}
}
c := newOpenAIClient()
var reqMsgs []gogpt.ChatCompletionMessage
if 7 < len(contextMsgs) {
contextMsgs = contextMsgs[len(contextMsgs)-7:]
}
c := gogpt.NewClientWithConfig(config)
ctx, cancel := context.WithTimeout(context.Background(), OpenAIAPITimeout)
defer cancel()
for _, ctxMsg := range contextMsgs {
reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{
Role: "user",
Content: ctxMsg,
})
}
reqMsgs = append(reqMsgs, gogpt.ChatCompletionMessage{
Role: "user",
Content: msg,
})
req := gogpt.ChatCompletionRequest{
Model: gogpt.GPT3Dot5Turbo,
MaxTokens: OpenAIAPIMaxTokens,
Messages: []gogpt.ChatCompletionMessage{
{
Role: "user",
Content: msg,
},
},
Messages: reqMsgs,
}
ctx, cancel := context.WithTimeout(context.Background(), OpenAIAPITimeout)
defer cancel()
stream, err := c.CreateChatCompletionStream(ctx, req)
if nil != err {
logging.LogErrorf("create chat completion stream failed: %s", err)
@ -100,9 +109,23 @@ func ChatGPT(msg string) (ret string) {
ret = buf.String()
ret = strings.TrimSpace(ret)
retContextMsgs = append(retContextMsgs, msg, ret)
return
}
func newOpenAIClient() *gogpt.Client {
config := gogpt.DefaultConfig(OpenAIAPIKey)
if "" != OpenAIAPIProxy {
proxyUrl, err := url.Parse(OpenAIAPIProxy)
if nil != err {
logging.LogErrorf("OpenAI API proxy failed: %v", err)
} else {
config.HTTPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}}
}
}
return gogpt.NewClientWithConfig(config)
}
func initOpenAI() {
OpenAIAPIKey = os.Getenv("SIYUAN_OPENAI_API_KEY")
if "" == OpenAIAPIKey {