Add ZRANGEBYSCORE. Cover rank case for ZRANGE

This commit is contained in:
Roman Gershman 2022-03-18 05:12:22 +02:00
parent 0611a3e760
commit cb0d8dfee2
5 changed files with 206 additions and 70 deletions

View File

@ -108,8 +108,8 @@ API 1.0
- [X] ZADD
- [X] ZCARD
- [ ] ZINCRBY
- [ ] ZRANGE
- [ ] ZRANGEBYSCORE
- [X] ZRANGE
- [X] ZRANGEBYSCORE
- [X] ZREM
- [ ] ZREMRANGEBYSCORE
- [ ] ZREVRANGE

View File

@ -96,5 +96,6 @@ int zzlLexValueLteMax(unsigned char* p, const zlexrangespec* spec);
int zslLexValueGteMin(sds value, const zlexrangespec* spec);
int zslLexValueLteMax(sds value, const zlexrangespec* spec);
int zsetZiplistValidateIntegrity(unsigned char* zl, size_t size, int deep);
zskiplistNode* zslGetElementByRank(zskiplist *zsl, unsigned long rank);
#endif

View File

@ -53,14 +53,9 @@ OpResult<MainIterator> FindZEntry(unsigned flags, const OpArgs& op_args, string_
return it;
}
struct ZListParams {
uint32_t offset = 0;
uint32_t limit = UINT32_MAX;
};
class IntervalVisitor {
public:
IntervalVisitor(const ZListParams& params, robj* o) : params_(params), zobj_(o) {
IntervalVisitor(const ZSetFamily::RangeParams& params, robj* o) : params_(params), zobj_(o) {
}
void operator()(const ZSetFamily::IndexInterval& ii);
@ -87,7 +82,9 @@ class IntervalVisitor {
return reverse_ ? zslValueGteMin(score, &spec) : zslValueLteMax(score, &spec);
}
ZListParams params_;
void AddResult(const uint8_t* vstr, unsigned vlen, long long vlon, double score);
ZSetFamily::RangeParams params_;
robj* zobj_;
bool reverse_ = false;
@ -95,7 +92,77 @@ class IntervalVisitor {
};
void IntervalVisitor::operator()(const ZSetFamily::IndexInterval& ii) {
LOG(FATAL) << "TBD";
unsigned long llen = zsetLength(zobj_);
int32_t start = ii.first;
int32_t end = ii.second;
if (start < 0)
start = llen + start;
if (end < 0)
end = llen + end;
if (start < 0)
start = 0;
if (start > end || unsigned(start) >= llen) {
return;
}
if (unsigned(end) >= llen)
end = llen - 1;
unsigned rangelen = (end - start) + 1;
if (zobj_->encoding == OBJ_ENCODING_LISTPACK) {
unsigned char* zl = (uint8_t*)zobj_->ptr;
unsigned char *eptr, *sptr;
unsigned char* vstr;
unsigned int vlen;
long long vlong;
double score = 0.0;
if (reverse_)
eptr = lpSeek(zl, -2 - (2 * start));
else
eptr = lpSeek(zl, 2 * start);
sptr = lpNext(zl, eptr);
while (rangelen--) {
DCHECK(eptr != NULL && sptr != NULL);
vstr = lpGetValue(eptr, &vlen, &vlong);
if (params_.with_scores) /* don't bother to extract the score if it's gonna be ignored. */
score = zzlGetScore(sptr);
AddResult(vstr, vlen, vlong, score);
Next(zl, &eptr, &sptr);
}
} else if (zobj_->encoding == OBJ_ENCODING_SKIPLIST) {
zset* zs = (zset*)zobj_->ptr;
zskiplist* zsl = zs->zsl;
zskiplistNode* ln;
/* Check if starting point is trivial, before doing log(N) lookup. */
if (reverse_) {
ln = zsl->tail;
if (start > 0)
ln = zslGetElementByRank(zsl, llen - start);
} else {
ln = zsl->header->level[0].forward;
if (start > 0)
ln = zslGetElementByRank(zsl, start + 1);
}
while (rangelen--) {
DCHECK(ln != NULL);
sds ele = ln->ele;
result_.emplace_back(string(ele, sdslen(ele)), ln->score);
ln = reverse_ ? ln->backward : ln->level[0].forward;
}
} else {
LOG(FATAL) << "Unknown sorted set encoding" << zobj_->encoding;
}
}
void IntervalVisitor::ExtractListPack(const zrangespec& range) {
@ -136,14 +203,9 @@ void IntervalVisitor::ExtractListPack(const zrangespec& range) {
* succeed */
vstr = lpGetValue(eptr, &vlen, &vlong);
rangelen++;
if (vstr == NULL) {
result_.emplace_back(absl::StrCat(vlong), score);
} else {
result_.emplace_back(string{reinterpret_cast<char*>(vstr), vlen}, score);
// handler->emitResultFromCBuffer(handler, vstr, vlen, score);
}
AddResult(vstr, vlen, vlong, score);
rangelen++;
/* Move to next node */
Next(zl, &eptr, &sptr);
}
@ -196,7 +258,7 @@ void IntervalVisitor::operator()(const ZSetFamily::ScoreInterval& si) {
range.min = si.first.val;
range.max = si.second.val;
range.minex = si.first.is_open;
range.maxex = si.first.is_open;
range.maxex = si.second.is_open;
if (zobj_->encoding == OBJ_ENCODING_LISTPACK) {
ExtractListPack(range);
@ -207,17 +269,37 @@ void IntervalVisitor::operator()(const ZSetFamily::ScoreInterval& si) {
}
}
bool ParseScore(string_view src, double* d) {
if (src == "-inf") {
*d = -HUGE_VAL;
} else if (src == "+inf") {
*d = HUGE_VAL;
void IntervalVisitor::AddResult(const uint8_t* vstr, unsigned vlen, long long vlong, double score) {
if (vstr == NULL) {
result_.emplace_back(absl::StrCat(vlong), score);
} else {
return absl::SimpleAtod(src, d);
result_.emplace_back(string{reinterpret_cast<const char*>(vstr), vlen}, score);
}
}
bool ParseScore(string_view src, double* score) {
if (src == "-inf") {
*score = -HUGE_VAL;
} else if (src == "+inf") {
*score = HUGE_VAL;
} else {
return absl::SimpleAtod(src, score);
}
return true;
};
bool ParseBound(string_view src, ZSetFamily::Bound* bound) {
if (src.empty())
return false;
if (src[0] == '(') {
bound->is_open = true;
src.remove_prefix(1);
}
return ParseScore(src, &bound->val);
}
} // namespace
void ZSetFamily::ZCard(CmdArgList args, ConnectionContext* cntx) {
@ -331,12 +413,8 @@ void ZSetFamily::ZRange(CmdArgList args, ConnectionContext* cntx) {
std::string_view min_s = ArgS(args, 2);
std::string_view max_s = ArgS(args, 3);
if (min_s.empty() || max_s.empty()) {
return (*cntx)->SendError(kInvalidIntErr);
}
ZRangeSpec range_spec;
bool parse_score = false;
RangeParams range_params;
for (size_t i = 4; i < args.size(); ++i) {
ToUpper(&args[i]);
@ -344,58 +422,57 @@ void ZSetFamily::ZRange(CmdArgList args, ConnectionContext* cntx) {
string_view cur_arg = ArgS(args, i);
if (cur_arg == "BYSCORE") {
parse_score = true;
} else if (cur_arg == "WITHSCORES") {
range_params.with_scores = true;
} else {
return cntx->reply_builder()->SendError(absl::StrCat("unsupported option ", cur_arg));
}
}
if (parse_score) {
ScoreInterval si;
if (min_s[0] == '(') {
si.first.is_open = true;
min_s.remove_prefix(1);
}
if (max_s[0] == '(') {
si.second.is_open = true;
max_s.remove_prefix(1);
}
if (!ParseScore(min_s, &si.first.val) || !ParseScore(max_s, &si.second.val)) {
return (*cntx)->SendError("min or max is not a float");
}
range_spec.interval = si;
} else {
IndexInterval ii;
if (!absl::SimpleAtoi(min_s, &ii.first) || !absl::SimpleAtoi(max_s, &ii.second)) {
(*cntx)->SendError(kInvalidIntErr);
return;
}
range_spec.interval = ii;
ZRangeByScoreInternal(key, min_s, max_s, range_params, cntx);
return;
}
IndexInterval ii;
if (!absl::SimpleAtoi(min_s, &ii.first) || !absl::SimpleAtoi(max_s, &ii.second)) {
(*cntx)->SendError(kInvalidIntErr);
return;
}
ZRangeSpec range_spec;
range_spec.params = range_params;
range_spec.interval = ii;
auto cb = [&](Transaction* t, EngineShard* shard) {
OpArgs op_args{shard, t->db_index()};
return OpRange(range_spec, op_args, key);
};
OpResult<ScoredArray> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result.status() == OpStatus::WRONG_TYPE) {
(*cntx)->SendError(kWrongTypeErr);
} else {
(*cntx)->StartArray(result.value().size());
for (const auto& p : result.value()) {
(*cntx)->SendBulkString(p.first);
if (false) { // withscores
(*cntx)->SendDouble(p.second);
}
}
}
OpResult<ScoredArray> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
OutputScoredArrayResult(result, range_params.with_scores, cntx);
}
void ZSetFamily::ZRangeByScore(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
std::string_view min_s = ArgS(args, 2);
std::string_view max_s = ArgS(args, 3);
RangeParams range_params;
for (size_t i = 4; i < args.size(); ++i) {
ToUpper(&args[i]);
string_view cur_arg = ArgS(args, i);
if (cur_arg == "WITHSCORES") {
range_params.with_scores = true;
} else {
return cntx->reply_builder()->SendError(absl::StrCat("unsupported option ", cur_arg));
}
}
ZRangeByScoreInternal(key, min_s, max_s, range_params, cntx);
}
void ZSetFamily::ZRem(CmdArgList args, ConnectionContext* cntx) {
@ -438,6 +515,47 @@ void ZSetFamily::ZScore(CmdArgList args, ConnectionContext* cntx) {
}
}
void ZSetFamily::ZRangeByScoreInternal(std::string_view key, std::string_view min_s,
std::string_view max_s, const RangeParams& params,
ConnectionContext* cntx) {
ZRangeSpec range_spec;
range_spec.params = params;
ScoreInterval si;
if (!ParseBound(min_s, &si.first) ||
!ParseBound(max_s, &si.second)) {
return (*cntx)->SendError("min or max is not a float");
}
range_spec.interval = si;
auto cb = [&](Transaction* t, EngineShard* shard) {
OpArgs op_args{shard, t->db_index()};
return OpRange(range_spec, op_args, key);
};
OpResult<ScoredArray> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
OutputScoredArrayResult(result, params.with_scores, cntx);
}
void ZSetFamily::OutputScoredArrayResult(const OpResult<ScoredArray>& result, bool with_scores,
ConnectionContext* cntx) {
if (result.status() == OpStatus::WRONG_TYPE) {
return (*cntx)->SendError(kWrongTypeErr);
}
LOG_IF(WARNING, !result && result.status() != OpStatus::KEY_NOTFOUND)
<< "Unexpected status " << result.status();
(*cntx)->StartArray(result->size() * (with_scores ? 2 : 1));
for (const auto& p : result.value()) {
(*cntx)->SendBulkString(p.first);
if (with_scores) {
(*cntx)->SendDouble(p.second);
}
}
}
OpResult<unsigned> ZSetFamily::OpAdd(const ZParams& zparams, const OpArgs& op_args, string_view key,
ScoredMemberSpan members) {
DCHECK(!members.empty());
@ -523,8 +641,7 @@ auto ZSetFamily::OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args, st
return res_it.status();
robj* zobj = res_it.value()->second.AsRObj();
ZListParams params;
IntervalVisitor iv{params, zobj};
IntervalVisitor iv{range_spec.params, zobj};
absl::visit(iv, range_spec.interval);

View File

@ -27,15 +27,23 @@ class ZSetFamily {
using ScoreInterval = std::pair<Bound, Bound>;
struct RangeParams {
uint32_t offset = 0;
uint32_t limit = UINT32_MAX;
bool with_scores = false;
};
struct ZRangeSpec {
std::variant<IndexInterval, ScoreInterval> interval;
// TODO: handle open/close, inf etc.
RangeParams params;
};
using ScoredMember = std::pair<std::string, double>;
using ScoredArray = std::vector<ScoredMember>;
private:
template <typename T> using OpResult = facade::OpResult<T>;
static void ZCard(CmdArgList args, ConnectionContext* cntx);
static void ZAdd(CmdArgList args, ConnectionContext* cntx);
static void ZIncrBy(CmdArgList args, ConnectionContext* cntx);
@ -44,6 +52,12 @@ class ZSetFamily {
static void ZScore(CmdArgList args, ConnectionContext* cntx);
static void ZRangeByScore(CmdArgList args, ConnectionContext* cntx);
static void ZRangeByScoreInternal(std::string_view key, std::string_view min_s,
std::string_view max_s, const RangeParams& params,
ConnectionContext* cntx);
static void OutputScoredArrayResult(const OpResult<ScoredArray>& arr, bool with_scores,
ConnectionContext* cntx);
struct ZParams {
unsigned flags = 0; // mask of ZADD_IN_ macros.
bool ch = false; // Corresponds to CH option.
@ -51,7 +65,6 @@ class ZSetFamily {
using ScoredMemberView = std::pair<double, std::string_view>;
using ScoredMemberSpan = absl::Span<ScoredMemberView>;
template <typename T> using OpResult = facade::OpResult<T>;
static OpResult<unsigned> OpAdd(const ZParams& zparams, const OpArgs& op_args,
std::string_view key, ScoredMemberSpan members);
@ -60,7 +73,6 @@ class ZSetFamily {
std::string_view member);
static OpResult<ScoredArray> OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args,
std::string_view key);
};
} // namespace dfly

View File

@ -48,10 +48,16 @@ TEST_F(ZSetFamilyTest, ZRem) {
resp = Run({"zrem", "x", "b", "c"});
EXPECT_THAT(resp[0], IntArg(1));
resp = Run({"zcard", "x"});
EXPECT_THAT(resp[0], IntArg(1));
EXPECT_THAT(Run({"zrange", "x", "0", "3", "byscore"}), ElementsAre("a"));
EXPECT_THAT(Run({"zrange", "x", "(-inf", "(+inf", "byscore"}), ElementsAre("a"));
}
TEST_F(ZSetFamilyTest, ZRange) {
Run({"zadd", "x", "1.1", "a", "2.1", "b"});
EXPECT_THAT(Run({"zrangebyscore", "x", "0", "(1.1"}), ElementsAre(ArrLen(0)));
}
} // namespace dfly