Add ZSCAN command

This commit is contained in:
Roman Gershman 2022-04-25 00:14:40 +03:00
parent b50428d1a6
commit d64a0c6f0e
4 changed files with 125 additions and 2 deletions

View File

@ -897,6 +897,7 @@ OpResult<StringVec> HSetFamily::OpScan(const OpArgs& op_args, std::string_view k
res.emplace_back(reinterpret_cast<char*>(elem), size_t(ele_len));
lp_elem = lpNext(lp, lp_elem); // switch to value
} while (lp_elem);
*cursor = 0;
} else {
dict* ht = (dict*)hset->ptr;
long maxiterations = count * 10;

View File

@ -11,6 +11,8 @@ extern "C" {
#include "redis/zset.h"
}
#include <double-conversion/double-to-string.h>
#include "base/logging.h"
#include "facade/error.h"
#include "server/command_registry.h"
@ -20,6 +22,7 @@ extern "C" {
namespace dfly {
using namespace double_conversion;
using namespace std;
using namespace facade;
using absl::SimpleAtoi;
@ -321,7 +324,7 @@ void IntervalVisitor::ActionRem(const zrangespec& range) {
}
void IntervalVisitor::ActionRem(const zlexrangespec& range) {
if (zobj_->encoding == OBJ_ENCODING_LISTPACK) {
if (zobj_->encoding == OBJ_ENCODING_LISTPACK) {
uint8_t* zl = (uint8_t*)zobj_->ptr;
unsigned long deleted = 0;
zl = zzlDeleteRangeByLex(zl, &range, &deleted);
@ -960,6 +963,37 @@ void ZSetFamily::ZScore(CmdArgList args, ConnectionContext* cntx) {
}
}
void ZSetFamily::ZScan(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);
string_view token = ArgS(args, 2);
uint64_t cursor = 0;
if (!absl::SimpleAtoi(token, &cursor)) {
return (*cntx)->SendError("invalid cursor");
}
if (args.size() > 3) {
return (*cntx)->SendError("scan options are not supported yet");
}
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpScan(OpArgs{shard, t->db_index()}, key, &cursor);
};
OpResult<StringVec> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result.status() != OpStatus::WRONG_TYPE) {
(*cntx)->StartArray(2);
(*cntx)->SendSimpleString(absl::StrCat(cursor));
(*cntx)->StartArray(result->size());
for (const auto& k : *result) {
(*cntx)->SendBulkString(k);
}
} else {
(*cntx)->SendError(result.status());
}
}
void ZSetFamily::ZRangeByScoreInternal(string_view key, string_view min_s, string_view max_s,
const RangeParams& params, ConnectionContext* cntx) {
ZRangeSpec range_spec;
@ -1082,6 +1116,69 @@ void ZSetFamily::ZRankGeneric(CmdArgList args, bool reverse, ConnectionContext*
}
}
OpResult<StringVec> ZSetFamily::OpScan(const OpArgs& op_args, std::string_view key,
uint64_t* cursor) {
OpResult<PrimeIterator> find_res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_ZSET);
if (!find_res)
return find_res.status();
PrimeIterator it = find_res.value();
StringVec res;
robj* zobj = it->second.AsRObj();
char buf[128];
if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
RangeParams params;
IntervalVisitor iv{Action::RANGE, params, zobj};
iv(IndexInterval{0, kuint32max});
ScoredArray arr = iv.PopResult();
res.resize(arr.size() * 2);
for (size_t i = 0; i < arr.size(); ++i) {
StringBuilder sb(buf, sizeof(buf));
CHECK(DoubleToStringConverter::EcmaScriptConverter().ToShortest(arr[i].second, &sb));
res[2 * i] = std::move(arr[i].first);
res[2 * i + 1].assign(sb.Finalize());
}
*cursor = 0;
} else {
CHECK_EQ(unsigned(OBJ_ENCODING_SKIPLIST), zobj->encoding);
uint32_t count = 20;
zset* zs = (zset*)zobj->ptr;
dict* ht = zs->dict;
long maxiterations = count * 10;
struct ScanArgs {
char* sbuf;
StringVec* res;
} sargs = {buf, &res};
auto scanCb = [](void* privdata, const dictEntry* de) {
ScanArgs* sargs = (ScanArgs*)privdata;
sds key = (sds)de->key;
double score = *(double*)dictGetVal(de);
sargs->res->emplace_back(key, sdslen(key));
StringBuilder sb(sargs->sbuf, sizeof(buf));
CHECK(DoubleToStringConverter::EcmaScriptConverter().ToShortest(score, &sb));
sargs->res->emplace_back(sb.Finalize());
};
do {
*cursor = dictScan(ht, *cursor, scanCb, NULL, &sargs);
} while (*cursor && maxiterations-- && res.size() < count);
}
return res;
}
OpStatus ZSetFamily::OpAdd(const ZParams& zparams, const OpArgs& op_args, string_view key,
ScoredMemberSpan members, AddResult* add_result) {
DCHECK(!members.empty());
@ -1379,7 +1476,8 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"ZREMRANGEBYLEX", CO::WRITE, 4, 1, 1, 1}.HFUNC(ZRemRangeByLex)
<< CI{"ZREVRANGE", CO::READONLY, 4, 1, 1, 1}.HFUNC(ZRevRange)
<< CI{"ZREVRANGEBYSCORE", CO::READONLY, 4, 1, 1, 1}.HFUNC(ZRevRangeByScore)
<< CI{"ZREVRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZRevRank);
<< CI{"ZREVRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZRevRank)
<< CI{"ZSCAN", CO::READONLY | CO::RANDOM, -3, 1, 1, 1}.HFUNC(ZScan);
}
} // namespace dfly

View File

@ -69,6 +69,7 @@ class ZSetFamily {
static void ZRevRange(CmdArgList args, ConnectionContext* cntx);
static void ZRevRangeByScore(CmdArgList args, ConnectionContext* cntx);
static void ZRevRank(CmdArgList args, ConnectionContext* cntx);
static void ZScan(CmdArgList args, ConnectionContext* cntx);
static void ZRangeByScoreInternal(std::string_view key, std::string_view min_s,
std::string_view max_s, const RangeParams& params,
@ -79,6 +80,7 @@ class ZSetFamily {
ConnectionContext* cntx);
static void ZRangeGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx);
static void ZRankGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx);
static OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor);
struct ZParams {
unsigned flags = 0; // mask of ZADD_IN_ macros.

View File

@ -134,4 +134,26 @@ TEST_F(ZSetFamilyTest, ByLex) {
ASSERT_THAT(resp.GetVec(), ElementsAre("alpha", "bar", "cool", "down", "elephant", "foo"));
}
TEST_F(ZSetFamilyTest, ZScan) {
string prefix(128,'a');
for (unsigned i = 0; i < 100; ++i) {
Run({"zadd", "key", "1", absl::StrCat(prefix, i)});
}
EXPECT_EQ(100, CheckedInt({"zcard", "key"}));
int64_t cursor = 0;
size_t scan_len = 0;
do {
auto resp = Run({"zscan", "key", absl::StrCat(cursor)});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
ASSERT_THAT(resp.GetVec(), ElementsAre(ArgType(RespExpr::STRING), ArgType(RespExpr::ARRAY)));
string_view token = ToSV(resp.GetVec()[0].GetBuf());
ASSERT_TRUE(absl::SimpleAtoi(token, &cursor));
auto sub_arr = resp.GetVec()[1].GetVec();
scan_len += sub_arr.size();
} while (cursor != 0);
EXPECT_EQ(100 * 2, scan_len);
}
} // namespace dfly