feat(server): Implement ZPOPMIN and ZPOPMAX #358 #359 (#378)

* Implements ZPOPMIN and ZPOPMAX commands
This commit is contained in:
RedhaL 2022-10-13 14:01:59 +02:00 committed by GitHub
parent 28706715dc
commit 2e875c81c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 204 additions and 2 deletions

View File

@ -3,6 +3,7 @@
* **[Amir Alperin](https://github.com/iko1)** * **[Amir Alperin](https://github.com/iko1)**
* **[Philipp Born](https://github.com/tamcore)** * **[Philipp Born](https://github.com/tamcore)**
* Helm Chart * Helm Chart
* **[Redha Lhimeur](https://github.com/redhal)**
* **[Braydn Moore](https://github.com/braydnm)** * **[Braydn Moore](https://github.com/braydnm)**
* **[Logan Raarup](https://github.com/logandk)** * **[Logan Raarup](https://github.com/logandk)**
* **[Ryan Russell](https://github.com/ryanrussell)** * **[Ryan Russell](https://github.com/ryanrussell)**

View File

@ -125,6 +125,7 @@ OpResult<PrimeIterator> FindZEntry(const ZParams& zparams, const OpArgs& op_args
enum class Action { enum class Action {
RANGE = 0, RANGE = 0,
REMOVE = 1, REMOVE = 1,
POP = 2
}; };
class IntervalVisitor { class IntervalVisitor {
@ -139,6 +140,8 @@ class IntervalVisitor {
void operator()(const ZSetFamily::LexInterval& li); void operator()(const ZSetFamily::LexInterval& li);
void operator()(ZSetFamily::TopNScored sc);
ZSetFamily::ScoredArray PopResult() { ZSetFamily::ScoredArray PopResult() {
return std::move(result_); return std::move(result_);
} }
@ -154,6 +157,9 @@ class IntervalVisitor {
void ExtractListPack(const zlexrangespec& range); void ExtractListPack(const zlexrangespec& range);
void ExtractSkipList(const zlexrangespec& range); void ExtractSkipList(const zlexrangespec& range);
void PopListPack(ZSetFamily::TopNScored sc);
void PopSkipList(ZSetFamily::TopNScored sc);
void ActionRange(unsigned start, unsigned end); // rank void ActionRange(unsigned start, unsigned end); // rank
void ActionRange(const zrangespec& range); // score void ActionRange(const zrangespec& range); // score
void ActionRange(const zlexrangespec& range); // lex void ActionRange(const zlexrangespec& range); // lex
@ -162,6 +168,8 @@ class IntervalVisitor {
void ActionRem(const zrangespec& range); // score void ActionRem(const zrangespec& range); // score
void ActionRem(const zlexrangespec& range); // lex void ActionRem(const zlexrangespec& range); // lex
void ActionPop(ZSetFamily::TopNScored sc);
void Next(uint8_t* zl, uint8_t** eptr, uint8_t** sptr) const { void Next(uint8_t* zl, uint8_t** eptr, uint8_t** sptr) const {
if (params_.reverse) { if (params_.reverse) {
zzlPrev(zl, eptr, sptr); zzlPrev(zl, eptr, sptr);
@ -214,6 +222,8 @@ void IntervalVisitor::operator()(const ZSetFamily::IndexInterval& ii) {
case Action::REMOVE: case Action::REMOVE:
ActionRem(start, end); ActionRem(start, end);
break; break;
default:
break;
} }
} }
@ -227,6 +237,8 @@ void IntervalVisitor::operator()(const ZSetFamily::ScoreInterval& si) {
case Action::REMOVE: case Action::REMOVE:
ActionRem(range); ActionRem(range);
break; break;
default:
break;
} }
} }
@ -240,10 +252,22 @@ void IntervalVisitor::operator()(const ZSetFamily::LexInterval& li) {
case Action::REMOVE: case Action::REMOVE:
ActionRem(range); ActionRem(range);
break; break;
default:
break;
} }
zslFreeLexRange(&range); zslFreeLexRange(&range);
} }
void IntervalVisitor::operator()(ZSetFamily::TopNScored sc) {
switch (action_) {
case Action::POP:
ActionPop(sc);
break;
default:
break;
}
}
void IntervalVisitor::ActionRange(unsigned start, unsigned end) { void IntervalVisitor::ActionRange(unsigned start, unsigned end) {
container_utils::IterateSortedSet(zobj_, [this](container_utils::ContainerEntry ce, double score){ container_utils::IterateSortedSet(zobj_, [this](container_utils::ContainerEntry ce, double score){
result_.emplace_back(ce.ToString(), score); result_.emplace_back(ce.ToString(), score);
@ -311,6 +335,15 @@ void IntervalVisitor::ActionRem(const zlexrangespec& range) {
} }
} }
void IntervalVisitor::ActionPop(ZSetFamily::TopNScored sc) {
if (zobj_->encoding == OBJ_ENCODING_LISTPACK) {
PopListPack(sc);
} else {
CHECK_EQ(zobj_->encoding, OBJ_ENCODING_SKIPLIST);
PopSkipList(sc);
}
}
void IntervalVisitor::ExtractListPack(const zrangespec& range) { void IntervalVisitor::ExtractListPack(const zrangespec& range) {
uint8_t* zl = (uint8_t*)zobj_->ptr; uint8_t* zl = (uint8_t*)zobj_->ptr;
uint8_t *eptr, *sptr; uint8_t *eptr, *sptr;
@ -472,6 +505,67 @@ void IntervalVisitor::ExtractSkipList(const zlexrangespec& range) {
} }
} }
void IntervalVisitor::PopListPack(ZSetFamily::TopNScored sc) {
uint8_t* zl = (uint8_t*)zobj_->ptr;
uint8_t *eptr, *sptr;
uint8_t* vstr;
unsigned int vlen = 0;
long long vlong = 0;
if (params_.reverse) {
eptr = lpSeek(zl,-2);
} else {
eptr = lpSeek(zl,0);
}
/* Get score pointer for the first element. */
if (eptr)
sptr = lpNext(zl, eptr);
/* First we get the entries */
unsigned int num = sc;
while (eptr && num--) {
double score = zzlGetScore(sptr);
vstr = lpGetValue(eptr, &vlen, &vlong);
AddResult(vstr, vlen, vlong, score);
/* Move to next node */
Next(zl, &eptr, &sptr);
}
int start = 0;
if (params_.reverse) {
/* If the number of elements to delete is greater than the listpack length,
* we set the start to 0 because lpseek fails to search beyond length in reverse */
start = (2*sc > lpLength(zl)) ? 0 : -2*sc;
}
/* We can finally delete the elements */
zobj_->ptr = lpDeleteRange(zl, start, 2*sc);
}
void IntervalVisitor::PopSkipList(ZSetFamily::TopNScored sc) {
zset* zs = (zset*)zobj_->ptr;
zskiplist* zsl = zs->zsl;
zskiplistNode* ln;
/* We start from the header, or the tail if reversed. */
if (params_.reverse) {
ln = zsl->tail;
} else {
ln = zsl->header;
}
while (ln && sc--) {
result_.emplace_back(string{ln->ele, sdslen(ln->ele)}, ln->score);
/* we can delete the element now */
zsetDel(zobj_, ln->ele);
ln = Next(ln);
}
}
void IntervalVisitor::AddResult(const uint8_t* vstr, unsigned vlen, long long vlong, double score) { void IntervalVisitor::AddResult(const uint8_t* vstr, unsigned vlen, long long vlong, double score) {
if (vstr == NULL) { if (vstr == NULL) {
result_.emplace_back(absl::StrCat(vlong), score); result_.emplace_back(absl::StrCat(vlong), score);
@ -1078,6 +1172,14 @@ void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendLong(smvec.size()); (*cntx)->SendLong(smvec.size());
} }
void ZSetFamily::ZPopMax(CmdArgList args, ConnectionContext* cntx) {
ZPopMinMax(std::move(args), true, cntx);
}
void ZSetFamily::ZPopMin(CmdArgList args, ConnectionContext* cntx) {
ZPopMinMax(std::move(args), false, cntx);
}
void ZSetFamily::ZLexCount(CmdArgList args, ConnectionContext* cntx) { void ZSetFamily::ZLexCount(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 1); string_view key = ArgS(args, 1);
@ -1532,6 +1634,30 @@ bool ZSetFamily::ParseRangeByScoreParams(CmdArgList args, RangeParams* params) {
return true; return true;
} }
void ZSetFamily::ZPopMinMax(CmdArgList args, bool reverse, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);
string_view count = ArgS(args, 2);
RangeParams range_params;
range_params.reverse = reverse;
ZRangeSpec range_spec;
range_spec.params = range_params;
TopNScored sc;
if (!SimpleAtoi(count, &sc)) {
return (*cntx)->SendError(kUintErr);
}
range_spec.interval = sc;
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpPopCount(range_spec, t->GetOpArgs(shard), key);
};
OpResult<ScoredArray> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
OutputScoredArrayResult(result, range_params, cntx);
}
OpResult<StringVec> ZSetFamily::OpScan(const OpArgs& op_args, std::string_view key, OpResult<StringVec> ZSetFamily::OpScan(const OpArgs& op_args, std::string_view key,
uint64_t* cursor) { uint64_t* cursor) {
OpResult<PrimeIterator> find_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET); OpResult<PrimeIterator> find_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET);
@ -1658,6 +1784,30 @@ OpResult<ZSetFamily::MScoreResponse> ZSetFamily::OpMScore(const OpArgs& op_args,
return scores; return scores;
} }
auto ZSetFamily::OpPopCount(const ZRangeSpec& range_spec, const OpArgs& op_args, string_view key) -> OpResult<ScoredArray> {
auto& db_slice = op_args.shard->db_slice();
OpResult<PrimeIterator> res_it = db_slice.Find(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
db_slice.PreUpdate(op_args.db_cntx.db_index, *res_it);
robj* zobj = res_it.value()->second.AsRObj();
IntervalVisitor iv{Action::POP, range_spec.params, zobj};
std::visit(iv, range_spec.interval);
res_it.value()->second.SyncRObj();
db_slice.PostUpdate(op_args.db_cntx.db_index, *res_it, key);
auto zlen = zsetLength(zobj);
if (zlen == 0) {
CHECK(op_args.shard->db_slice().Del(op_args.db_cntx.db_index, res_it.value()));
}
return iv.PopResult();
}
auto ZSetFamily::OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args, string_view key) auto ZSetFamily::OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args, string_view key)
-> OpResult<ScoredArray> { -> OpResult<ScoredArray> {
OpResult<PrimeIterator> res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET); OpResult<PrimeIterator> res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET);
@ -1857,6 +2007,8 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"ZINCRBY", CO::FAST | CO::WRITE | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(ZIncrBy) << CI{"ZINCRBY", CO::FAST | CO::WRITE | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(ZIncrBy)
<< CI{"ZINTERSTORE", kUnionMask, -4, 3, 3, 1}.HFUNC(ZInterStore) << CI{"ZINTERSTORE", kUnionMask, -4, 3, 3, 1}.HFUNC(ZInterStore)
<< CI{"ZLEXCOUNT", CO::READONLY, 4, 1, 1, 1}.HFUNC(ZLexCount) << CI{"ZLEXCOUNT", CO::READONLY, 4, 1, 1, 1}.HFUNC(ZLexCount)
<< CI{"ZPOPMAX", CO::READONLY, 3, 1, 1, 1}.HFUNC(ZPopMax)
<< CI{"ZPOPMIN", CO::READONLY, 3, 1, 1, 1}.HFUNC(ZPopMin)
<< CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, 1}.HFUNC(ZRem) << CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, 1}.HFUNC(ZRem)
<< CI{"ZRANGE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRange) << CI{"ZRANGE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRange)
<< CI{"ZRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZRank) << CI{"ZRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZRank)

View File

@ -34,6 +34,8 @@ class ZSetFamily {
using LexInterval = std::pair<LexBound, LexBound>; using LexInterval = std::pair<LexBound, LexBound>;
using TopNScored = uint32_t;
struct RangeParams { struct RangeParams {
uint32_t offset = 0; uint32_t offset = 0;
uint32_t limit = UINT32_MAX; uint32_t limit = UINT32_MAX;
@ -42,7 +44,7 @@ class ZSetFamily {
}; };
struct ZRangeSpec { struct ZRangeSpec {
std::variant<IndexInterval, ScoreInterval, LexInterval> interval; std::variant<IndexInterval, ScoreInterval, LexInterval, TopNScored> interval;
RangeParams params; RangeParams params;
}; };
@ -58,6 +60,8 @@ class ZSetFamily {
static void ZIncrBy(CmdArgList args, ConnectionContext* cntx); static void ZIncrBy(CmdArgList args, ConnectionContext* cntx);
static void ZInterStore(CmdArgList args, ConnectionContext* cntx); static void ZInterStore(CmdArgList args, ConnectionContext* cntx);
static void ZLexCount(CmdArgList args, ConnectionContext* cntx); static void ZLexCount(CmdArgList args, ConnectionContext* cntx);
static void ZPopMax(CmdArgList args, ConnectionContext* cntx);
static void ZPopMin(CmdArgList args, ConnectionContext* cntx);
static void ZRange(CmdArgList args, ConnectionContext* cntx); static void ZRange(CmdArgList args, ConnectionContext* cntx);
static void ZRank(CmdArgList args, ConnectionContext* cntx); static void ZRank(CmdArgList args, ConnectionContext* cntx);
static void ZRem(CmdArgList args, ConnectionContext* cntx); static void ZRem(CmdArgList args, ConnectionContext* cntx);
@ -84,7 +88,7 @@ class ZSetFamily {
static void ZRangeGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx); static void ZRangeGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx);
static void ZRankGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx); static void ZRankGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx);
static bool ParseRangeByScoreParams(CmdArgList args, RangeParams* params); static bool ParseRangeByScoreParams(CmdArgList args, RangeParams* params);
static void ZPopMinMax(CmdArgList args, bool reverse, ConnectionContext* cntx);
static OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor); static OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor);
static OpResult<unsigned> OpRem(const OpArgs& op_args, std::string_view key, ArgSlice members); static OpResult<unsigned> OpRem(const OpArgs& op_args, std::string_view key, ArgSlice members);
@ -93,6 +97,8 @@ class ZSetFamily {
using MScoreResponse = std::vector<std::optional<double>>; using MScoreResponse = std::vector<std::optional<double>>;
static OpResult<MScoreResponse> OpMScore(const OpArgs& op_args, std::string_view key, static OpResult<MScoreResponse> OpMScore(const OpArgs& op_args, std::string_view key,
ArgSlice members); ArgSlice members);
static OpResult<ScoredArray> OpPopCount(const ZRangeSpec& range_spec, const OpArgs& op_args,
std::string_view key);
static OpResult<ScoredArray> OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args, static OpResult<ScoredArray> OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args,
std::string_view key); std::string_view key);
static OpResult<unsigned> OpRemRange(const OpArgs& op_args, std::string_view key, static OpResult<unsigned> OpRemRange(const OpArgs& op_args, std::string_view key,

View File

@ -267,4 +267,47 @@ TEST_F(ZSetFamilyTest, ZAddBug148) {
EXPECT_THAT(resp, IntArg(1)); EXPECT_THAT(resp, IntArg(1));
} }
TEST_F(ZSetFamilyTest, ZPopMin) {
auto resp = Run({"zadd", "key", "1", "a", "2", "b", "3", "c", "4", "d", "5", "e"});
EXPECT_THAT(resp, IntArg(5));
resp = Run({"zpopmin", "key", "2"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "b"));
resp = Run({"zpopmin", "key", "-1"});
ASSERT_THAT(resp, ErrArg("value is out of range, must be positive"));
resp = Run({"zpopmin", "key", "1"});
ASSERT_THAT(resp, "c");
resp = Run({"zpopmin", "key", "3"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp.GetVec(), ElementsAre("d", "e"));
resp = Run({"zpopmin", "key", "1"});
ASSERT_THAT(resp, ArrLen(0));
}
TEST_F(ZSetFamilyTest, ZPopMax) {
auto resp = Run({"zadd", "key", "1", "a", "2", "b", "3", "c", "4", "d", "5", "e"});
EXPECT_THAT(resp, IntArg(5));
resp = Run({"zpopmax", "key", "2"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp.GetVec(), ElementsAre("e", "d"));
resp = Run({"zpopmax", "key", "-1"});
ASSERT_THAT(resp, ErrArg("value is out of range, must be positive"));
resp = Run({"zpopmax", "key", "1"});
ASSERT_THAT(resp, "c");
resp = Run({"zpopmax", "key", "3"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp.GetVec(), ElementsAre("b", "a"));
resp = Run({"zpopmax", "key", "1"});
ASSERT_THAT(resp, ArrLen(0));
}
} // namespace dfly } // namespace dfly