From a48154ba845476695a8dd76e6543e2f6d87487d5 Mon Sep 17 00:00:00 2001 From: Liang Ding Date: Thu, 4 May 2023 10:11:29 +0800 Subject: [PATCH] :art: API `/api/query/sql` add `LIMIT` clause https://github.com/siyuan-note/siyuan/issues/8167 --- kernel/api/sql.go | 3 +- kernel/model/search.go | 8 ++-- kernel/sql/block_query.go | 88 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 91 insertions(+), 8 deletions(-) diff --git a/kernel/api/sql.go b/kernel/api/sql.go index 50c194fc0..0fd180f5c 100644 --- a/kernel/api/sql.go +++ b/kernel/api/sql.go @@ -21,6 +21,7 @@ import ( "github.com/88250/gulu" "github.com/gin-gonic/gin" + "github.com/siyuan-note/siyuan/kernel/model" "github.com/siyuan-note/siyuan/kernel/sql" "github.com/siyuan-note/siyuan/kernel/util" ) @@ -35,7 +36,7 @@ func SQL(c *gin.Context) { } stmt := arg["stmt"].(string) - result, err := sql.Query(stmt) + result, err := sql.Query(stmt, model.Conf.Search.Limit) if nil != err { ret.Code = 1 ret.Msg = err.Error() diff --git a/kernel/model/search.go b/kernel/model/search.go index f6a94beb9..30ce519d4 100644 --- a/kernel/model/search.go +++ b/kernel/model/search.go @@ -624,7 +624,7 @@ func searchBySQL(stmt string, beforeLen, page int) (ret []*Block, matchedBlockCo stmt = strings.ReplaceAll(stmt, "select * ", "select COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` ") } stmt = removeLimitClause(stmt) - result, _ := sql.Query(stmt) + result, _ := sql.QueryNoLimit(stmt) if 1 > len(ret) { return } @@ -745,7 +745,7 @@ func fullTextSearchCountByRegexp(exp, boxFilter, pathFilter, typeFilter string) fieldFilter := fieldRegexp(exp) stmt := "SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `blocks` WHERE " + fieldFilter + " AND type IN " + typeFilter stmt += boxFilter + pathFilter - result, _ := sql.Query(stmt) + result, _ := sql.QueryNoLimit(stmt) if 1 > len(result) { return } @@ -785,7 +785,7 @@ func fullTextSearchByFTS(query, boxFilter, pathFilter, typeFilter, orderBy strin func fullTextSearchCount(query, boxFilter, pathFilter, typeFilter string) (matchedBlockCount, matchedRootCount int) { query = gulu.Str.RemoveInvisible(query) if ast.IsNodeIDPattern(query) { - ret, _ := sql.Query("SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `blocks` WHERE `id` = '" + query + "'") + ret, _ := sql.QueryNoLimit("SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `blocks` WHERE `id` = '" + query + "'") if 1 > len(ret) { return } @@ -802,7 +802,7 @@ func fullTextSearchCount(query, boxFilter, pathFilter, typeFilter string) (match stmt := "SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `" + table + "` WHERE (`" + table + "` MATCH '" + columnFilter() + ":(" + query + ")'" stmt += ") AND type IN " + typeFilter stmt += boxFilter + pathFilter - result, _ := sql.Query(stmt) + result, _ := sql.QueryNoLimit(stmt) if 1 > len(result) { return } diff --git a/kernel/sql/block_query.go b/kernel/sql/block_query.go index b1085e1e4..f5e214cf0 100644 --- a/kernel/sql/block_query.go +++ b/kernel/sql/block_query.go @@ -19,6 +19,7 @@ package sql import ( "bytes" "database/sql" + "math" "sort" "strconv" "strings" @@ -378,7 +379,45 @@ func QueryBookmarkLabels() (ret []string) { return } -func Query(stmt string) (ret []map[string]interface{}, err error) { +func QueryNoLimit(stmt string) (ret []map[string]interface{}, err error) { + return queryRawStmt(stmt, math.MaxInt) +} + +func Query(stmt string, limit int) (ret []map[string]interface{}, err error) { + parsedStmt, err := sqlparser.Parse(stmt) + if nil != err { + return queryRawStmt(stmt, limit) + } + + switch parsedStmt.(type) { + case *sqlparser.Select: + slct := parsedStmt.(*sqlparser.Select) + if nil == slct.Limit { + slct.Limit = &sqlparser.Limit{ + Rowcount: &sqlparser.SQLVal{ + Type: sqlparser.IntVal, + Val: []byte(strconv.Itoa(limit)), + }, + } + } else { + if nil != slct.Limit.Rowcount && 0 < len(slct.Limit.Rowcount.(*sqlparser.SQLVal).Val) { + limit, _ = strconv.Atoi(string(slct.Limit.Rowcount.(*sqlparser.SQLVal).Val)) + if 0 >= limit { + limit = 32 + } + } + + slct.Limit.Rowcount = &sqlparser.SQLVal{ + Type: sqlparser.IntVal, + Val: []byte(strconv.Itoa(limit)), + } + } + + stmt = sqlparser.String(slct) + default: + return + } + ret = []map[string]interface{}{} rows, err := query(stmt) if nil != err { @@ -413,6 +452,49 @@ func Query(stmt string) (ret []map[string]interface{}, err error) { return } +func queryRawStmt(stmt string, limit int) (ret []map[string]interface{}, err error) { + rows, err := query(stmt) + if nil != err { + if strings.Contains(err.Error(), "syntax error") { + return + } + return + } + defer rows.Close() + + cols, err := rows.Columns() + if nil != err || nil == cols { + return + } + + noLimit := !strings.Contains(strings.ToLower(stmt), " limit ") + var count, errCount int + for rows.Next() { + columns := make([]interface{}, len(cols)) + columnPointers := make([]interface{}, len(cols)) + for i := range columns { + columnPointers[i] = &columns[i] + } + + if err = rows.Scan(columnPointers...); nil != err { + return + } + + m := make(map[string]interface{}) + for i, colName := range cols { + val := columnPointers[i].(*interface{}) + m[colName] = *val + } + + ret = append(ret, m) + count++ + if (noLimit && limit < count) || 0 < errCount { + break + } + } + return +} + func SelectBlocksRawStmtNoParse(stmt string, limit int) (ret []*Block) { return selectBlocksRawStmt(stmt, limit) } @@ -491,7 +573,7 @@ func selectBlocksRawStmt(stmt string, limit int) (ret []*Block) { } defer rows.Close() - confLimit := !strings.Contains(strings.ToLower(stmt), " limit ") + noLimit := !strings.Contains(strings.ToLower(stmt), " limit ") var count, errCount int for rows.Next() { count++ @@ -502,7 +584,7 @@ func selectBlocksRawStmt(stmt string, limit int) (ret []*Block) { errCount++ } - if (confLimit && limit < count) || 0 < errCount { + if (noLimit && limit < count) || 0 < errCount { break } }