diff --git a/kernel/model/ai.go b/kernel/model/ai.go index 409f8b77a..bdc5ea1c4 100644 --- a/kernel/model/ai.go +++ b/kernel/model/ai.go @@ -77,7 +77,7 @@ func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret str if cloud { gpt = &CloudGPT{} } else { - gpt = &OpenAIGPT{c: util.NewOpenAIClient()} + gpt = &OpenAIGPT{c: util.NewOpenAIClient(Conf.AI.OpenAI.APIKey, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIBaseURL)} } buf := &bytes.Buffer{} @@ -99,7 +99,7 @@ func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret str } func isOpenAIAPIEnabled() bool { - if "" == util.OpenAIAPIKey { + if "" == Conf.AI.OpenAI.APIKey { util.PushMsg(Conf.Language(193), 5000) return false } @@ -155,7 +155,7 @@ type OpenAIGPT struct { } func (gpt *OpenAIGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) { - return util.ChatGPT(msg, contextMsgs, gpt.c) + return util.ChatGPT(msg, contextMsgs, gpt.c, Conf.AI.OpenAI.APIMaxTokens, Conf.AI.OpenAI.APITimeout) } type CloudGPT struct { diff --git a/kernel/model/conf.go b/kernel/model/conf.go index 7566ae72f..5a238ce0c 100644 --- a/kernel/model/conf.go +++ b/kernel/model/conf.go @@ -323,6 +323,15 @@ func InitConf() { Conf.AI = conf.NewAI() } + if "" != Conf.AI.OpenAI.APIKey { + logging.LogInfof("OpenAI API enabled\n"+ + " baseURL=%s\n"+ + " timeout=%ds\n"+ + " proxy=%s\n"+ + " maxTokens=%d", + Conf.AI.OpenAI.APIBaseURL, Conf.AI.OpenAI.APITimeout, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIMaxTokens) + } + Conf.ReadOnly = util.ReadOnly if "" != util.AccessAuthCode { diff --git a/kernel/util/openai.go b/kernel/util/openai.go index 335222879..6fddc51b6 100644 --- a/kernel/util/openai.go +++ b/kernel/util/openai.go @@ -18,26 +18,15 @@ package util import ( "context" - "net/http" - "net/url" - "os" - "strconv" - "strings" - "time" - gogpt "github.com/sashabaranov/go-gpt3" "github.com/siyuan-note/logging" + "net/http" + "net/url" + "strings" + "time" ) -var ( - OpenAIAPIKey = "" - OpenAIAPITimeout = 30 * time.Second - OpenAIAPIProxy = "" - OpenAIAPIMaxTokens = 0 - OpenAIAPIBaseURL = "https://api.openai.com/v1" -) - -func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, stop bool, err error) { +func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client, maxTokens, timeout int) (ret string, stop bool, err error) { var reqMsgs []gogpt.ChatCompletionMessage for _, ctxMsg := range contextMsgs { @@ -53,10 +42,10 @@ func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, sto req := gogpt.ChatCompletionRequest{ Model: gogpt.GPT3Dot5Turbo, - MaxTokens: OpenAIAPIMaxTokens, + MaxTokens: maxTokens, Messages: reqMsgs, } - ctx, cancel := context.WithTimeout(context.Background(), OpenAIAPITimeout) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) defer cancel() resp, err := c.CreateChatCompletion(ctx, req) if nil != err { @@ -85,10 +74,10 @@ func ChatGPT(msg string, contextMsgs []string, c *gogpt.Client) (ret string, sto return } -func NewOpenAIClient() *gogpt.Client { - config := gogpt.DefaultConfig(OpenAIAPIKey) - if "" != OpenAIAPIProxy { - proxyUrl, err := url.Parse(OpenAIAPIProxy) +func NewOpenAIClient(apiKey, apiProxy, apiBaseURL string) *gogpt.Client { + config := gogpt.DefaultConfig(apiKey) + if "" != apiProxy { + proxyUrl, err := url.Parse(apiProxy) if nil != err { logging.LogErrorf("OpenAI API proxy failed: %v", err) } else { @@ -96,46 +85,6 @@ func NewOpenAIClient() *gogpt.Client { } } - config.BaseURL = OpenAIAPIBaseURL + config.BaseURL = apiBaseURL return gogpt.NewClientWithConfig(config) } - -func initOpenAI() { - OpenAIAPIKey = os.Getenv("SIYUAN_OPENAI_API_KEY") - if "" == OpenAIAPIKey { - return - } - - timeout := os.Getenv("SIYUAN_OPENAI_API_TIMEOUT") - if "" != timeout { - timeoutInt, err := strconv.Atoi(timeout) - if nil == err { - OpenAIAPITimeout = time.Duration(timeoutInt) * time.Second - } - } - - proxy := os.Getenv("SIYUAN_OPENAI_API_PROXY") - if "" != proxy { - OpenAIAPIProxy = proxy - } - - maxTokens := os.Getenv("SIYUAN_OPENAI_API_MAX_TOKENS") - if "" != maxTokens { - maxTokensInt, err := strconv.Atoi(maxTokens) - if nil == err { - OpenAIAPIMaxTokens = maxTokensInt - } - } - - baseURL := os.Getenv("SIYUAN_OPENAI_API_BASE_URL") - if "" != baseURL { - OpenAIAPIBaseURL = baseURL - } - - logging.LogInfof("OpenAI API enabled\n"+ - " baseURL=%s\n"+ - " timeout=%ds\n"+ - " proxy=%s\n"+ - " maxTokens=%d", - OpenAIAPIBaseURL, int(OpenAIAPITimeout.Seconds()), OpenAIAPIProxy, OpenAIAPIMaxTokens) -} diff --git a/kernel/util/working.go b/kernel/util/working.go index fe0c1e7b2..e7c0656dc 100644 --- a/kernel/util/working.go +++ b/kernel/util/working.go @@ -118,8 +118,6 @@ func Boot() { bootBanner := figure.NewColorFigure("SiYuan", "isometric3", "green", true) logging.LogInfof("\n" + bootBanner.String()) logBootInfo() - - initOpenAI() } func setBootDetails(details string) {