diff --git a/kernel/go.mod b/kernel/go.mod index d049c5937..e0f720cdf 100644 --- a/kernel/go.mod +++ b/kernel/go.mod @@ -139,6 +139,7 @@ require ( github.com/richardlehane/mscfb v1.0.4 // indirect github.com/richardlehane/msoleps v1.0.3 // indirect github.com/rivo/uniseg v0.4.4 // indirect + github.com/rqlite/sql v0.0.0-20221103124402-8f9ff0ceb8f0 // indirect github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/shopspring/decimal v1.3.1 // indirect diff --git a/kernel/go.sum b/kernel/go.sum index d81330578..0d38f819b 100644 --- a/kernel/go.sum +++ b/kernel/go.sum @@ -335,6 +335,8 @@ github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rqlite/sql v0.0.0-20221103124402-8f9ff0ceb8f0 h1:C8DZB5okjhCSd7zvkOM+zxGz7S6ulUFIL34bpkqFk+0= +github.com/rqlite/sql v0.0.0-20221103124402-8f9ff0ceb8f0/go.mod h1:ib9zVtNgRKiGuoMyUqqL5aNpk+r+++YlyiVIkclVqPg= github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 h1:OkMGxebDjyw0ULyrTYWeN0UNCCkmCWfjPnIA2W6oviI= github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06/go.mod h1:+ePHsJ1keEjQtpvf9HHw0f4ZeJ0TLRsxhunSI2hYJSs= github.com/sashabaranov/go-openai v1.17.6 h1:hYXRPM1xO6QLOJhWEOMlSg/l3jERiKDKd1qIoK22lvs= diff --git a/kernel/sql/block_query.go b/kernel/sql/block_query.go index 26197ba5d..385d1d2de 100644 --- a/kernel/sql/block_query.go +++ b/kernel/sql/block_query.go @@ -27,6 +27,7 @@ import ( "github.com/88250/lute/ast" "github.com/88250/vitess-sqlparser/sqlparser" "github.com/emirpasic/gods/sets/hashset" + sqlparser2 "github.com/rqlite/sql" "github.com/siyuan-note/logging" "github.com/siyuan-note/siyuan/kernel/treenode" "github.com/siyuan-note/siyuan/kernel/util" @@ -384,24 +385,48 @@ func QueryNoLimit(stmt string) (ret []map[string]interface{}, err error) { } func Query(stmt string, limit int) (ret []map[string]interface{}, err error) { - parsedStmt, err := sqlparser.Parse(stmt) + // Kernel API `/api/query/sql` support `||` operator https://github.com/siyuan-note/siyuan/issues/9662 + // 这里为了支持 || 操作符,使用了另一个 sql 解析器,但是这个解析器无法处理 UNION https://github.com/siyuan-note/siyuan/issues/8226 + // 考虑到 UNION 的使用场景不多,这里还是以支持 || 操作符为主 + p := sqlparser2.NewParser(strings.NewReader(stmt)) + parsedStmt2, err := p.ParseStatement() if nil != err { - return queryRawStmt(stmt, limit) - } + if !strings.Contains(stmt, "||") { + // 这个解析器无法处理 || 连接字符串操作符 + parsedStmt, err2 := sqlparser.Parse(stmt) + if nil != err2 { + return queryRawStmt(stmt, limit) + } - switch parsedStmt.(type) { - case *sqlparser.Select: - limitClause := getLimitClause(parsedStmt, limit) - slct := parsedStmt.(*sqlparser.Select) - slct.Limit = limitClause - stmt = sqlparser.String(slct) - case *sqlparser.Union: - limitClause := getLimitClause(parsedStmt, limit) - union := parsedStmt.(*sqlparser.Union) - union.Limit = limitClause - stmt = sqlparser.String(union) - default: - return queryRawStmt(stmt, limit) + switch parsedStmt.(type) { + case *sqlparser.Select: + limitClause := getLimitClause(parsedStmt, limit) + slct := parsedStmt.(*sqlparser.Select) + slct.Limit = limitClause + stmt = sqlparser.String(slct) + case *sqlparser.Union: + // Kernel API `/api/query/sql` support `UNION` statement https://github.com/siyuan-note/siyuan/issues/8226 + limitClause := getLimitClause(parsedStmt, limit) + union := parsedStmt.(*sqlparser.Union) + union.Limit = limitClause + stmt = sqlparser.String(union) + default: + return queryRawStmt(stmt, limit) + } + } else { + return queryRawStmt(stmt, limit) + } + } else { + switch parsedStmt2.(type) { + case *sqlparser2.SelectStatement: + slct := parsedStmt2.(*sqlparser2.SelectStatement) + if nil == slct.LimitExpr { + slct.LimitExpr = &sqlparser2.NumberLit{Value: strconv.Itoa(limit)} + } + stmt = slct.String() + default: + return queryRawStmt(stmt, limit) + } } ret = []map[string]interface{}{}