Add ZRANK,ZCOUNT,ZREVRANK.

1. Fix #12 - return number of added items for non-increment usecase.
2. Fix #15 - fix double precision response. I use a different printing algorithm that of Redis
   therefore there could be string differences between 2 systems. However, both replies should
   be equivalent numerically.
3. Fix #13.  Reject ZADD with LT and GT options together.
4. Fix #11 - return correct error when parsing invalid scores.
This commit is contained in:
Roman Gershman 2022-04-06 22:54:10 +03:00
parent 92ebb74500
commit fa70267729
8 changed files with 261 additions and 61 deletions

View File

@ -29,6 +29,12 @@ add_third_party(
<SOURCE_DIR>/luaconf.h ${THIRD_PARTY_LIB_DIR}/lua/include
)
add_third_party(
dconv
URL https://github.com/google/double-conversion/archive/refs/tags/v3.2.0.tar.gz
LIB libdouble-conversion.a
)
Message(STATUS "THIRD_PARTY_LIB_DIR ${THIRD_PARTY_LIB_DIR}")
include_directories(src)

View File

@ -205,15 +205,15 @@ API 2.0
- [X] Set Family
- [X] SSCAN
- [X] Sorted Set Family
- [ ] ZCOUNT
- [X] ZCOUNT
- [ ] ZINTERSTORE
- [ ] ZLEXCOUNT
- [ ] ZRANGEBYLEX
- [ ] ZRANK
- [X] ZRANK
- [ ] ZREMRANGEBYLEX
- [X] ZREMRANGEBYRANK
- [ ] ZREVRANGEBYSCORE
- [ ] ZREVRANK
- [X] ZREVRANK
- [ ] ZUNIONSTORE
- [ ] ZSCAN
- [ ] HYPERLOGLOG Family

View File

@ -1,7 +1,7 @@
add_library(dfly_facade dragonfly_listener.cc dragonfly_connection.cc facade.cc
memcache_parser.cc redis_parser.cc reply_builder.cc)
cxx_link(dfly_facade base uring_fiber_lib fibers_ext strings_lib http_server_lib
tls_lib TRDP::mimalloc)
tls_lib TRDP::mimalloc TRDP::dconv)
add_library(facade_test facade_test.cc)
cxx_link(facade_test dfly_facade gtest_main_ext)

View File

@ -5,12 +5,14 @@
#include <absl/strings/numbers.h>
#include <absl/strings/str_cat.h>
#include <double-conversion/double-to-string.h>
#include "base/logging.h"
#include "facade/error.h"
using namespace std;
using absl::StrAppend;
using namespace double_conversion;
namespace facade {
@ -100,7 +102,6 @@ void MCReplyBuilder::SendSimpleString(std::string_view str) {
Send(v, ABSL_ARRAYSIZE(v));
}
void MCReplyBuilder::SendStored() {
SendSimpleString("STORED");
}
@ -116,8 +117,7 @@ void MCReplyBuilder::SendMGetResponse(const OptResp* resp, uint32_t count) {
for (unsigned i = 0; i < count; ++i) {
if (resp[i]) {
const auto& src = *resp[i];
absl::StrAppend(&header, "VALUE ", src.key, " ", src.mc_flag, " ",
src.value.size());
absl::StrAppend(&header, "VALUE ", src.key, " ", src.mc_flag, " ", src.value.size());
if (src.mc_ver) {
absl::StrAppend(&header, " ", src.mc_ver);
}
@ -224,7 +224,11 @@ void RedisReplyBuilder::SendLong(long num) {
}
void RedisReplyBuilder::SendDouble(double val) {
SendBulkString(absl::StrCat(val));
char buf[64];
StringBuilder sb(buf, sizeof(buf));
CHECK(DoubleToStringConverter::EcmaScriptConverter().ToShortest(val, &sb));
SendBulkString(sb.Finalize());
}
void RedisReplyBuilder::SendMGetResponse(const OptResp* resp, uint32_t count) {
@ -286,4 +290,4 @@ void ReqSerializer::SendCommand(std::string_view str) {
ec_ = sink_->Write(v, ABSL_ARRAYSIZE(v));
}
} // namespace dfly
} // namespace facade

View File

@ -75,7 +75,7 @@ unsigned long zsetLength(const robj* zobj);
void zsetConvert(robj* zobj, int encoding);
void zsetConvertToZiplistIfNeeded(robj* zobj, size_t maxelelen);
int zsetScore(robj* zobj, sds member, double* score);
// unsigned long zslGetRank(zskiplist *zsl, double score, sds o);
unsigned long zslGetRank(zskiplist *zsl, double score, sds ele);
int zsetAdd(robj* zobj, double score, sds ele, int in_flags, int* out_flags, double* newscore);
long zsetRank(robj* zobj, sds ele, int reverse);
int zsetDel(robj* zobj, sds ele);

View File

@ -7,6 +7,7 @@
extern "C" {
#include "redis/listpack.h"
#include "redis/object.h"
#include "redis/util.h"
#include "redis/zset.h"
}
@ -28,10 +29,22 @@ using CI = CommandId;
static const char kNxXxErr[] = "XX and NX options at the same time are not compatible";
static const char kScoreNaN[] = "resulting score is not a number (NaN)";
static const char kRangeErr[] = "min or max is not a float";
constexpr unsigned kMaxListPackValue = 64;
inline zrangespec GetZrangeSpec(const ZSetFamily::ScoreInterval& si) {
zrangespec range;
range.min = si.first.val;
range.max = si.second.val;
range.minex = si.first.is_open;
range.maxex = si.second.is_open;
return range;
}
OpResult<PrimeIterator> FindZEntry(unsigned flags, const OpArgs& op_args, string_view key,
size_t member_len) {
size_t member_len) {
auto& db_slice = op_args.shard->db_slice();
if (flags & ZADD_IN_XX) {
return db_slice.Find(op_args.db_ind, key, OBJ_ZSET);
@ -138,11 +151,7 @@ void IntervalVisitor::operator()(const ZSetFamily::IndexInterval& ii) {
}
void IntervalVisitor::operator()(const ZSetFamily::ScoreInterval& si) {
zrangespec range;
range.min = si.first.val;
range.max = si.second.val;
range.minex = si.first.is_open;
range.maxex = si.second.is_open;
zrangespec range = GetZrangeSpec(si);
switch (action_) {
case Action::RANGE:
@ -158,9 +167,9 @@ void IntervalVisitor::ActionRange(unsigned start, unsigned end) {
unsigned rangelen = (end - start) + 1;
if (zobj_->encoding == OBJ_ENCODING_LISTPACK) {
unsigned char* zl = (uint8_t*)zobj_->ptr;
unsigned char *eptr, *sptr;
unsigned char* vstr;
uint8_t* zl = (uint8_t*)zobj_->ptr;
uint8_t *eptr, *sptr;
uint8_t* vstr;
unsigned int vlen;
long long vlong;
double score = 0.0;
@ -349,12 +358,15 @@ void IntervalVisitor::AddResult(const uint8_t* vstr, unsigned vlen, long long vl
}
bool ParseScore(string_view src, double* score) {
if (src.empty())
return false;
if (src == "-inf") {
*score = -HUGE_VAL;
} else if (src == "+inf") {
*score = HUGE_VAL;
} else {
return absl::SimpleAtod(src, score);
return string2d(src.data(), src.size(), score);
}
return true;
};
@ -373,27 +385,6 @@ bool ParseBound(string_view src, ZSetFamily::Bound* bound) {
} // namespace
void ZSetFamily::ZCard(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<uint32_t> {
OpResult<PrimeIterator> find_res = shard->db_slice().Find(t->db_index(), key, OBJ_ZSET);
if (!find_res) {
return find_res.status();
}
return zsetLength(find_res.value()->second.AsRObj());
};
OpResult<uint32_t> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result.status() == OpStatus::WRONG_TYPE) {
(*cntx)->SendError(kWrongTypeErr);
return;
}
(*cntx)->SendLong(result.value());
}
void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);
@ -437,7 +428,9 @@ void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
return;
}
if ((zparams.flags & ZADD_IN_NX) && (zparams.flags & (ZADD_IN_GT | ZADD_IN_LT))) {
constexpr auto kRangeOpt = ZADD_IN_GT | ZADD_IN_LT;
if (((zparams.flags & ZADD_IN_NX) && (zparams.flags & kRangeOpt)) ||
((zparams.flags & kRangeOpt) == kRangeOpt)) {
(*cntx)->SendError("GT, LT, and/or NX options at the same time are not compatible");
return;
}
@ -445,7 +438,8 @@ void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
absl::InlinedVector<ScoredMemberView, 4> members;
for (; i < args.size(); i += 2) {
string_view cur_arg = ArgS(args, i);
double val;
double val = 0;
if (!ParseScore(cur_arg, &val)) {
return (*cntx)->SendError(kInvalidFloatErr);
}
@ -471,18 +465,66 @@ void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
}
// KEY_NOTFOUND may happen in case of XX flag.
if (status == OpStatus::SKIPPED || status == OpStatus::KEY_NOTFOUND) {
return (*cntx)->SendNull();
}
if (add_result.is_nan) {
return (*cntx)->SendError(kScoreNaN);
}
if (zparams.flags & ZADD_IN_INCR) {
(*cntx)->SendDouble(add_result.new_score);
if (status == OpStatus::KEY_NOTFOUND) {
if (zparams.flags & ZADD_IN_INCR)
(*cntx)->SendNull();
else
(*cntx)->SendLong(0);
} else if (status == OpStatus::SKIPPED) {
(*cntx)->SendNull();
} else if (add_result.is_nan) {
(*cntx)->SendError(kScoreNaN);
} else {
(*cntx)->SendLong(add_result.num_updated);
if (zparams.flags & ZADD_IN_INCR) {
(*cntx)->SendDouble(add_result.new_score);
} else {
(*cntx)->SendLong(add_result.num_updated);
}
}
}
void ZSetFamily::ZCard(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<uint32_t> {
OpResult<PrimeIterator> find_res = shard->db_slice().Find(t->db_index(), key, OBJ_ZSET);
if (!find_res) {
return find_res.status();
}
return zsetLength(find_res.value()->second.AsRObj());
};
OpResult<uint32_t> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result.status() == OpStatus::WRONG_TYPE) {
(*cntx)->SendError(kWrongTypeErr);
return;
}
(*cntx)->SendLong(result.value());
}
void ZSetFamily::ZCount(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);
string_view min_s = ArgS(args, 2);
string_view max_s = ArgS(args, 3);
ScoreInterval si;
if (!ParseBound(min_s, &si.first) || !ParseBound(max_s, &si.second)) {
return (*cntx)->SendError(kRangeErr);
}
auto cb = [&](Transaction* t, EngineShard* shard) {
OpArgs op_args{shard, t->db_index()};
return OpCount(op_args, key, si);
};
OpResult<unsigned> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result.status() == OpStatus::WRONG_TYPE) {
(*cntx)->SendError(kWrongTypeErr);
} else {
(*cntx)->SendLong(*result);
}
}
@ -531,10 +573,18 @@ void ZSetFamily::ZRange(CmdArgList args, ConnectionContext* cntx) {
ZRangeGeneric(std::move(args), false, cntx);
}
void ZSetFamily::ZRank(CmdArgList args, ConnectionContext* cntx) {
ZRankGeneric(std::move(args), false, cntx);
}
void ZSetFamily::ZRevRange(CmdArgList args, ConnectionContext* cntx) {
ZRangeGeneric(std::move(args), true, cntx);
}
void ZSetFamily::ZRevRank(CmdArgList args, ConnectionContext* cntx) {
ZRankGeneric(std::move(args), true, cntx);
}
void ZSetFamily::ZRangeByScore(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);
string_view min_s = ArgS(args, 2);
@ -578,7 +628,7 @@ void ZSetFamily::ZRemRangeByScore(CmdArgList args, ConnectionContext* cntx) {
ScoreInterval si;
if (!ParseBound(min_s, &si.first) || !ParseBound(max_s, &si.second)) {
return (*cntx)->SendError("min or max is not a float");
return (*cntx)->SendError(kRangeErr);
}
ZRangeSpec range_spec;
@ -635,7 +685,7 @@ void ZSetFamily::ZRangeByScoreInternal(string_view key, string_view min_s, strin
ScoreInterval si;
if (!ParseBound(min_s, &si.first) || !ParseBound(max_s, &si.second)) {
return (*cntx)->SendError("min or max is not a float");
return (*cntx)->SendError(kRangeErr);
}
range_spec.interval = si;
@ -730,6 +780,26 @@ void ZSetFamily::ZRangeGeneric(CmdArgList args, bool reverse, ConnectionContext*
OutputScoredArrayResult(result, range_params, cntx);
}
void ZSetFamily::ZRankGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);
string_view member = ArgS(args, 2);
auto cb = [&](Transaction* t, EngineShard* shard) {
OpArgs op_args{shard, t->db_index()};
return OpRank(op_args, key, member, reverse);
};
OpResult<unsigned> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result) {
(*cntx)->SendLong(*result);
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
(*cntx)->SendNull();
} else {
(*cntx)->SendError(result.status());
}
}
OpStatus ZSetFamily::OpAdd(const ZParams& zparams, const OpArgs& op_args, string_view key,
ScoredMemberSpan members, AddResult* add_result) {
DCHECK(!members.empty());
@ -859,19 +929,108 @@ OpResult<unsigned> ZSetFamily::OpRemRange(const OpArgs& op_args, string_view key
return iv.removed();
}
OpResult<unsigned> ZSetFamily::OpRank(const OpArgs& op_args, string_view key, string_view member,
bool reverse) {
OpResult<PrimeIterator> res_it = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
robj* zobj = res_it.value()->second.AsRObj();
op_args.shard->tmp_str1 = sdscpylen(op_args.shard->tmp_str1, member.data(), member.size());
long res = zsetRank(zobj, op_args.shard->tmp_str1, reverse);
if (res < 0)
return OpStatus::KEY_NOTFOUND;
return res;
}
OpResult<unsigned> ZSetFamily::OpCount(const OpArgs& op_args, std::string_view key,
const ScoreInterval& interval) {
OpResult<PrimeIterator> res_it = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
robj* zobj = res_it.value()->second.AsRObj();
zrangespec range = GetZrangeSpec(interval);
unsigned count = 0;
if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
uint8_t* zl = (uint8_t*)zobj->ptr;
uint8_t *eptr, *sptr;
double score;
/* Use the first element in range as the starting point */
eptr = zzlFirstInRange(zl, &range);
/* No "first" element */
if (eptr == NULL) {
return 0;
}
/* First element is in range */
sptr = lpNext(zl, eptr);
score = zzlGetScore(sptr);
DCHECK(zslValueLteMax(score, &range));
/* Iterate over elements in range */
while (eptr) {
score = zzlGetScore(sptr);
/* Abort when the node is no longer in range. */
if (!zslValueLteMax(score, &range)) {
break;
} else {
count++;
zzlNext(zl, &eptr, &sptr);
}
}
} else {
CHECK_EQ(unsigned(OBJ_ENCODING_SKIPLIST), zobj->encoding);
zset* zs = (zset*)zobj->ptr;
zskiplist* zsl = zs->zsl;
zskiplistNode* zn;
unsigned long rank;
/* Find first element in range */
zn = zslFirstInRange(zsl, &range);
/* Use rank of first element, if any, to determine preliminary count */
if (zn == NULL)
return 0;
rank = zslGetRank(zsl, zn->score, zn->ele);
count = (zsl->length - (rank - 1));
/* Find last element in range */
zn = zslLastInRange(zsl, &range);
/* Use rank of last element, if any, to determine the actual count */
if (zn != NULL) {
rank = zslGetRank(zsl, zn->score, zn->ele);
count -= (zsl->length - rank);
}
}
return count;
}
#define HFUNC(x) SetHandler(&ZSetFamily::x)
void ZSetFamily::Register(CommandRegistry* registry) {
*registry << CI{"ZCARD", CO::FAST | CO::READONLY, 2, 1, 1, 1}.HFUNC(ZCard)
<< CI{"ZADD", CO::FAST | CO::WRITE | CO::DENYOOM, -4, 1, 1, 1}.HFUNC(ZAdd)
*registry << CI{"ZADD", CO::FAST | CO::WRITE | CO::DENYOOM, -4, 1, 1, 1}.HFUNC(ZAdd)
<< CI{"ZCARD", CO::FAST | CO::READONLY, 2, 1, 1, 1}.HFUNC(ZCard)
<< CI{"ZCOUNT", CO::FAST | CO::READONLY, 4, 1, 1, 1}.HFUNC(ZCount)
<< CI{"ZINCRBY", CO::FAST | CO::WRITE | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(ZIncrBy)
<< CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, 1}.HFUNC(ZRem)
<< CI{"ZRANGE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRange)
<< CI{"ZRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZRank)
<< CI{"ZRANGEBYSCORE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRangeByScore)
<< CI{"ZSCORE", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZScore)
<< CI{"ZREMRANGEBYRANK", CO::WRITE, 4, 1, 1, 1}.HFUNC(ZRemRangeByRank)
<< CI{"ZREMRANGEBYSCORE", CO::WRITE, 4, 1, 1, 1}.HFUNC(ZRemRangeByScore)
<< CI{"ZREVRANGE", CO::WRITE, 4, 1, 1, 1}.HFUNC(ZRevRange);
<< CI{"ZREVRANGE", CO::WRITE, 4, 1, 1, 1}.HFUNC(ZRevRange)
<< CI{"ZREVRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZRevRank);
}
} // namespace dfly

View File

@ -45,16 +45,19 @@ class ZSetFamily {
private:
template <typename T> using OpResult = facade::OpResult<T>;
static void ZCard(CmdArgList args, ConnectionContext* cntx);
static void ZAdd(CmdArgList args, ConnectionContext* cntx);
static void ZCard(CmdArgList args, ConnectionContext* cntx);
static void ZCount(CmdArgList args, ConnectionContext* cntx);
static void ZIncrBy(CmdArgList args, ConnectionContext* cntx);
static void ZRange(CmdArgList args, ConnectionContext* cntx);
static void ZRank(CmdArgList args, ConnectionContext* cntx);
static void ZRem(CmdArgList args, ConnectionContext* cntx);
static void ZScore(CmdArgList args, ConnectionContext* cntx);
static void ZRangeByScore(CmdArgList args, ConnectionContext* cntx);
static void ZRemRangeByRank(CmdArgList args, ConnectionContext* cntx);
static void ZRemRangeByScore(CmdArgList args, ConnectionContext* cntx);
static void ZRevRange(CmdArgList args, ConnectionContext* cntx);
static void ZRevRank(CmdArgList args, ConnectionContext* cntx);
static void ZRangeByScoreInternal(std::string_view key, std::string_view min_s,
std::string_view max_s, const RangeParams& params,
@ -64,6 +67,7 @@ class ZSetFamily {
static void ZRemRangeGeneric(std::string_view key, const ZRangeSpec& range_spec,
ConnectionContext* cntx);
static void ZRangeGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx);
static void ZRankGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx);
struct ZParams {
unsigned flags = 0; // mask of ZADD_IN_ macros.
@ -89,6 +93,12 @@ class ZSetFamily {
std::string_view key);
static OpResult<unsigned> OpRemRange(const OpArgs& op_args, std::string_view key,
const ZRangeSpec& spec);
static OpResult<unsigned> OpRank(const OpArgs& op_args, std::string_view key,
std::string_view member, bool reverse);
static OpResult<unsigned> OpCount(const OpArgs& op_args, std::string_view key,
const ScoreInterval& interval);
};
} // namespace dfly

View File

@ -40,6 +40,16 @@ TEST_F(ZSetFamilyTest, Add) {
resp = Run({"zcard", "x"});
EXPECT_THAT(resp[0], IntArg(1));
EXPECT_THAT(Run({"zadd", "x", "", "a"}), ElementsAre(ErrArg("not a valid float")));
EXPECT_THAT(Run({"zadd", "ztmp", "xx", "10", "member"}), ElementsAre(IntArg(0)));
const char kHighPrecision[] = "0.79028573343077946";
Run({"zadd", "zs", kHighPrecision, "a"});
EXPECT_THAT(Run({"zscore", "zs", "a"}), ElementsAre("0.7902857334307795"));
EXPECT_EQ(0.79028573343077946, 0.7902857334307795);
}
TEST_F(ZSetFamilyTest, ZRem) {
@ -55,10 +65,21 @@ TEST_F(ZSetFamilyTest, ZRem) {
EXPECT_THAT(Run({"zrange", "x", "(-inf", "(+inf", "byscore"}), ElementsAre("a"));
}
TEST_F(ZSetFamilyTest, ZRange) {
TEST_F(ZSetFamilyTest, ZRangeRank) {
Run({"zadd", "x", "1.1", "a", "2.1", "b"});
EXPECT_THAT(Run({"zrangebyscore", "x", "0", "(1.1"}), ElementsAre(ArrLen(0)));
EXPECT_THAT(Run({"zrangebyscore", "x", "-inf", "1.1"}), ElementsAre("a"));
EXPECT_EQ(2, CheckedInt({"zcount", "x", "1.1", "2.1"}));
EXPECT_EQ(1, CheckedInt({"zcount", "x", "(1.1", "2.1"}));
EXPECT_EQ(0, CheckedInt({"zcount", "y", "(1.1", "2.1"}));
EXPECT_EQ(0, CheckedInt({"zrank", "x", "a"}));
EXPECT_EQ(1, CheckedInt({"zrank", "x", "b"}));
EXPECT_EQ(1, CheckedInt({"zrevrank", "x", "a"}));
EXPECT_EQ(0, CheckedInt({"zrevrank", "x", "b"}));
EXPECT_THAT(Run({"zrevrank", "x", "c"}), ElementsAre(ArgType(RespExpr::NIL)));
EXPECT_THAT(Run({"zrank", "y", "c"}), ElementsAre(ArgType(RespExpr::NIL)));
}
TEST_F(ZSetFamilyTest, ZRemRangeRank) {