diff --git a/kernel/api/search.go b/kernel/api/search.go index cba275d0e..385dd303e 100644 --- a/kernel/api/search.go +++ b/kernel/api/search.go @@ -212,7 +212,12 @@ func fullTextSearchBlock(c *gin.Context) { if nil != querySyntaxArg { querySyntax = querySyntaxArg.(bool) } - blocks, matchedBlockCount, matchedRootCount := model.FullTextSearchBlock(query, box, path, types, querySyntax) + groupByArg := arg["groupBy"] + var groupBy int // 0:不分组,1:按文档分组 + if nil != groupByArg { + groupBy = int(groupByArg.(float64)) + } + blocks, matchedBlockCount, matchedRootCount := model.FullTextSearchBlock(query, box, path, types, querySyntax, groupBy) ret.Data = map[string]interface{}{ "blocks": blocks, "matchedBlockCount": matchedBlockCount, diff --git a/kernel/model/search.go b/kernel/model/search.go index 070ba0eac..b24dc5ee7 100644 --- a/kernel/model/search.go +++ b/kernel/model/search.go @@ -253,13 +253,41 @@ func FindReplace(keyword, replacement string, ids []string) (err error) { return } -func FullTextSearchBlock(query, box, path string, types map[string]bool, querySyntax bool) (ret []*Block, matchedBlockCount, matchedRootCount int) { +func FullTextSearchBlock(query, box, path string, types map[string]bool, querySyntax bool, groupBy int) (ret []*Block, matchedBlockCount, matchedRootCount int) { query = strings.TrimSpace(query) + beforeLen := 36 + var blocks []*Block if queryStrLower := strings.ToLower(query); strings.Contains(queryStrLower, "select ") && strings.Contains(queryStrLower, " * ") && strings.Contains(queryStrLower, " from ") { - ret, matchedBlockCount, matchedRootCount = searchBySQL(query, 36) + blocks, matchedBlockCount, matchedRootCount = searchBySQL(query, beforeLen) } else { filter := searchFilter(types) - ret, matchedBlockCount, matchedRootCount = fullTextSearch(query, box, path, filter, 36, querySyntax) + blocks, matchedBlockCount, matchedRootCount = fullTextSearch(query, box, path, filter, beforeLen, querySyntax) + } + + switch groupBy { + case 0: // 不分组 + ret = blocks + case 1: // 按文档分组 + rootMap := map[string]bool{} + var rootIDs []string + for _, b := range blocks { + if _, ok := rootMap[b.RootID]; !ok { + rootMap[b.RootID] = true + rootIDs = append(rootIDs, b.RootID) + } + } + sqlRoots := sql.GetBlocks(rootIDs) + roots := fromSQLBlocks(&sqlRoots, "", beforeLen) + for _, root := range roots { + for _, b := range blocks { + if b.RootID == root.ID { + root.Children = append(root.Children, b) + } + } + } + ret = roots + default: + ret = blocks } return } diff --git a/kernel/sql/block_query.go b/kernel/sql/block_query.go index 7f3356e28..a1a4b1214 100644 --- a/kernel/sql/block_query.go +++ b/kernel/sql/block_query.go @@ -598,11 +598,29 @@ func GetAllRootBlocks() (ret []*Block) { } func GetBlocks(ids []string) (ret []*Block) { - length := len(ids) + var notHitIDs []string + cached := map[string]*Block{} + for _, id := range ids { + b := getBlockCache(id) + if nil != b { + cached[id] = b + } else { + notHitIDs = append(notHitIDs, id) + } + } + + if 1 > len(notHitIDs) { + for _, id := range ids { + ret = append(ret, cached[id]) + } + return + } + + length := len(notHitIDs) stmtBuilder := bytes.Buffer{} stmtBuilder.WriteString("SELECT * FROM blocks WHERE id IN (") var args []interface{} - for i, id := range ids { + for i, id := range notHitIDs { args = append(args, id) stmtBuilder.WriteByte('?') if i < length-1 { @@ -619,10 +637,13 @@ func GetBlocks(ids []string) (ret []*Block) { defer rows.Close() for rows.Next() { if block := scanBlockRows(rows); nil != block { - ret = append(ret, block) putBlockCache(block) + cached[block.ID] = block } } + for _, id := range ids { + ret = append(ret, cached[id]) + } return }