diff --git a/kernel/sql/block_query.go b/kernel/sql/block_query.go index f5e214cf0..290e0fa98 100644 --- a/kernel/sql/block_query.go +++ b/kernel/sql/block_query.go @@ -391,31 +391,17 @@ func Query(stmt string, limit int) (ret []map[string]interface{}, err error) { switch parsedStmt.(type) { case *sqlparser.Select: + limitClause := getLimitClause(parsedStmt, limit) slct := parsedStmt.(*sqlparser.Select) - if nil == slct.Limit { - slct.Limit = &sqlparser.Limit{ - Rowcount: &sqlparser.SQLVal{ - Type: sqlparser.IntVal, - Val: []byte(strconv.Itoa(limit)), - }, - } - } else { - if nil != slct.Limit.Rowcount && 0 < len(slct.Limit.Rowcount.(*sqlparser.SQLVal).Val) { - limit, _ = strconv.Atoi(string(slct.Limit.Rowcount.(*sqlparser.SQLVal).Val)) - if 0 >= limit { - limit = 32 - } - } - - slct.Limit.Rowcount = &sqlparser.SQLVal{ - Type: sqlparser.IntVal, - Val: []byte(strconv.Itoa(limit)), - } - } - + slct.Limit = limitClause stmt = sqlparser.String(slct) + case *sqlparser.Union: + limitClause := getLimitClause(parsedStmt, limit) + union := parsedStmt.(*sqlparser.Union) + union.Limit = limitClause + stmt = sqlparser.String(union) default: - return + return queryRawStmt(stmt, limit) } ret = []map[string]interface{}{} @@ -452,6 +438,31 @@ func Query(stmt string, limit int) (ret []map[string]interface{}, err error) { return } +func getLimitClause(parsedStmt sqlparser.Statement, limit int) (ret *sqlparser.Limit) { + switch parsedStmt.(type) { + case *sqlparser.Select: + slct := parsedStmt.(*sqlparser.Select) + if nil != slct.Limit { + ret = slct.Limit + } + case *sqlparser.Union: + union := parsedStmt.(*sqlparser.Union) + if nil != union.Limit { + ret = union.Limit + } + } + + if nil == ret || nil == ret.Rowcount { + ret = &sqlparser.Limit{ + Rowcount: &sqlparser.SQLVal{ + Type: sqlparser.IntVal, + Val: []byte(strconv.Itoa(limit)), + }, + } + } + return +} + func queryRawStmt(stmt string, limit int) (ret []map[string]interface{}, err error) { rows, err := query(stmt) if nil != err {