Implement ZINCRBY/ZADD INCR

This commit is contained in:
Roman Gershman 2022-03-30 14:25:42 +03:00
parent 39ef7bf630
commit b9c1288c67
4 changed files with 132 additions and 52 deletions

View File

@ -113,7 +113,7 @@ API 1.0
- [X] SortedSet Family
- [X] ZADD
- [X] ZCARD
- [ ] ZINCRBY
- [X] ZINCRBY
- [X] ZRANGE
- [X] ZRANGEBYSCORE
- [X] ZREM
@ -187,11 +187,11 @@ API 2.0
- [ ] HVALS
- [ ] HSCAN
- [ ] PubSub family
- [ ] PUBLISH
- [X] PUBLISH
- [ ] PUBSUB
- [ ] PUBSUB CHANNELS
- [ ] SUBSCRIBE
- [ ] UNSUBSCRIBE
- [X] SUBSCRIBE
- [X] UNSUBSCRIBE
- [ ] Server Family
- [ ] WATCH
- [ ] UNWATCH

View File

@ -27,6 +27,7 @@ namespace {
using CI = CommandId;
static const char kNxXxErr[] = "XX and NX options at the same time are not compatible";
static const char kScoreNaN[] = "resulting score is not a number (NaN)";
constexpr unsigned kMaxListPackValue = 64;
OpResult<MainIterator> FindZEntry(unsigned flags, const OpArgs& op_args, string_view key,
@ -394,14 +395,14 @@ void ZSetFamily::ZCard(CmdArgList args, ConnectionContext* cntx) {
}
void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
string_view key = ArgS(args, 1);
ZParams zparams;
size_t i = 2;
for (; i < args.size() - 1; ++i) {
ToUpper(&args[i]);
std::string_view cur_arg = ArgS(args, i);
string_view cur_arg = ArgS(args, i);
if (cur_arg == "XX") {
zparams.flags |= ZADD_IN_XX; // update only
@ -443,39 +444,87 @@ void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
absl::InlinedVector<ScoredMemberView, 4> members;
for (; i < args.size(); i += 2) {
std::string_view cur_arg = ArgS(args, i);
string_view cur_arg = ArgS(args, i);
double val;
if (!ParseScore(cur_arg, &val)) {
(*cntx)->SendError(kInvalidFloatErr);
return;
return (*cntx)->SendError(kInvalidFloatErr);
}
std::string_view member = ArgS(args, i + 1);
if (isnan(val)) {
return (*cntx)->SendError(kScoreNaN);
}
string_view member = ArgS(args, i + 1);
members.emplace_back(val, member);
}
DCHECK(cntx->transaction);
if (zparams.flags & ZADD_IN_INCR) {
LOG(FATAL) << "TBD";
return;
}
absl::Span memb_sp{members.data(), members.size()};
AddResult add_result;
auto cb = [&](Transaction* t, EngineShard* shard) {
auto cb = [&](Transaction* t, EngineShard* shard) -> OpStatus {
OpArgs op_args{shard, t->db_index()};
return OpAdd(zparams, op_args, key, memb_sp);
return OpAdd(zparams, op_args, key, memb_sp, &add_result);
};
OpResult<unsigned> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result.status() == OpStatus::WRONG_TYPE) {
(*cntx)->SendError(kWrongTypeErr);
OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb));
if (status == OpStatus::WRONG_TYPE) {
return (*cntx)->SendError(kWrongTypeErr);
}
// KEY_NOTFOUND may happen in case of XX flag.
if (status == OpStatus::SKIPPED || status == OpStatus::KEY_NOTFOUND) {
return (*cntx)->SendNull();
}
if (add_result.is_nan) {
return (*cntx)->SendError(kScoreNaN);
}
if (zparams.flags & ZADD_IN_INCR) {
(*cntx)->SendDouble(add_result.new_score);
} else {
(*cntx)->SendLong(result.value());
(*cntx)->SendLong(add_result.num_updated);
}
}
void ZSetFamily::ZIncrBy(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendLong(0);
string_view key = ArgS(args, 1);
string_view score_arg = ArgS(args, 2);
ScoredMemberView scored_member;
scored_member.second = ArgS(args, 3);
if (!absl::SimpleAtod(score_arg, &scored_member.first)) {
return (*cntx)->SendError(kInvalidFloatErr);
}
if (isnan(scored_member.first)) {
return (*cntx)->SendError(kScoreNaN);
}
ZParams zparams;
zparams.flags = ZADD_IN_INCR;
AddResult add_result;
auto cb = [&](Transaction* t, EngineShard* shard) -> OpStatus {
OpArgs op_args{shard, t->db_index()};
return OpAdd(zparams, op_args, key, ScoredMemberSpan{&scored_member, 1}, &add_result);
};
OpStatus status = cntx->transaction->ScheduleSingleHop(std::move(cb));
if (status == OpStatus::WRONG_TYPE) {
return (*cntx)->SendError(kWrongTypeErr);
}
if (status == OpStatus::SKIPPED) {
return (*cntx)->SendNull();
}
if (add_result.is_nan) {
return (*cntx)->SendError(kScoreNaN);
}
(*cntx)->SendDouble(add_result.new_score);
}
void ZSetFamily::ZRange(CmdArgList args, ConnectionContext* cntx) {
@ -487,9 +536,9 @@ void ZSetFamily::ZRevRange(CmdArgList args, ConnectionContext* cntx) {
}
void ZSetFamily::ZRangeByScore(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
std::string_view min_s = ArgS(args, 2);
std::string_view max_s = ArgS(args, 3);
string_view key = ArgS(args, 1);
string_view min_s = ArgS(args, 2);
string_view max_s = ArgS(args, 3);
RangeParams range_params;
@ -508,9 +557,9 @@ void ZSetFamily::ZRangeByScore(CmdArgList args, ConnectionContext* cntx) {
}
void ZSetFamily::ZRemRangeByRank(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
std::string_view min_s = ArgS(args, 2);
std::string_view max_s = ArgS(args, 3);
string_view key = ArgS(args, 1);
string_view min_s = ArgS(args, 2);
string_view max_s = ArgS(args, 3);
IndexInterval ii;
if (!absl::SimpleAtoi(min_s, &ii.first) || !absl::SimpleAtoi(max_s, &ii.second)) {
@ -523,9 +572,9 @@ void ZSetFamily::ZRemRangeByRank(CmdArgList args, ConnectionContext* cntx) {
}
void ZSetFamily::ZRemRangeByScore(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
std::string_view min_s = ArgS(args, 2);
std::string_view max_s = ArgS(args, 3);
string_view key = ArgS(args, 1);
string_view min_s = ArgS(args, 2);
string_view max_s = ArgS(args, 3);
ScoreInterval si;
if (!ParseBound(min_s, &si.first) || !ParseBound(max_s, &si.second)) {
@ -540,9 +589,9 @@ void ZSetFamily::ZRemRangeByScore(CmdArgList args, ConnectionContext* cntx) {
}
void ZSetFamily::ZRem(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
string_view key = ArgS(args, 1);
absl::InlinedVector<std::string_view, 8> members(args.size() - 2);
absl::InlinedVector<string_view, 8> members(args.size() - 2);
for (size_t i = 2; i < args.size(); ++i) {
members[i - 2] = ArgS(args, i);
}
@ -561,8 +610,8 @@ void ZSetFamily::ZRem(CmdArgList args, ConnectionContext* cntx) {
}
void ZSetFamily::ZScore(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
std::string_view member = ArgS(args, 2);
string_view key = ArgS(args, 1);
string_view member = ArgS(args, 2);
auto cb = [&](Transaction* t, EngineShard* shard) {
OpArgs op_args{shard, t->db_index()};
@ -579,9 +628,8 @@ void ZSetFamily::ZScore(CmdArgList args, ConnectionContext* cntx) {
}
}
void ZSetFamily::ZRangeByScoreInternal(std::string_view key, std::string_view min_s,
std::string_view max_s, const RangeParams& params,
ConnectionContext* cntx) {
void ZSetFamily::ZRangeByScoreInternal(string_view key, string_view min_s, string_view max_s,
const RangeParams& params, ConnectionContext* cntx) {
ZRangeSpec range_spec;
range_spec.params = params;
@ -620,7 +668,7 @@ void ZSetFamily::OutputScoredArrayResult(const OpResult<ScoredArray>& result,
}
}
void ZSetFamily::ZRemRangeGeneric(std::string_view key, const ZRangeSpec& range_spec,
void ZSetFamily::ZRemRangeGeneric(string_view key, const ZRangeSpec& range_spec,
ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* shard) {
OpArgs op_args{shard, t->db_index()};
@ -636,9 +684,9 @@ void ZSetFamily::ZRemRangeGeneric(std::string_view key, const ZRangeSpec& range_
}
void ZSetFamily::ZRangeGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
std::string_view min_s = ArgS(args, 2);
std::string_view max_s = ArgS(args, 3);
string_view key = ArgS(args, 1);
string_view min_s = ArgS(args, 2);
string_view max_s = ArgS(args, 3);
bool parse_score = false;
RangeParams range_params;
@ -682,8 +730,8 @@ void ZSetFamily::ZRangeGeneric(CmdArgList args, bool reverse, ConnectionContext*
OutputScoredArrayResult(result, range_params, cntx);
}
OpResult<unsigned> ZSetFamily::OpAdd(const ZParams& zparams, const OpArgs& op_args, string_view key,
ScoredMemberSpan members) {
OpStatus ZSetFamily::OpAdd(const ZParams& zparams, const OpArgs& op_args, string_view key,
ScoredMemberSpan members, AddResult* add_result) {
DCHECK(!members.empty());
OpResult<MainIterator> res_it =
FindZEntry(zparams.flags, op_args, key, members.front().second.size());
@ -698,16 +746,25 @@ OpResult<unsigned> ZSetFamily::OpAdd(const ZParams& zparams, const OpArgs& op_ar
unsigned processed = 0;
sds& tmp_str = op_args.shard->tmp_str1;
double new_score;
int retflags = 0;
for (size_t j = 0; j < members.size(); j++) {
const auto& m = members[j];
tmp_str = sdscpylen(tmp_str, m.second.data(), m.second.size());
int retflags = 0;
int retval = zsetAdd(zobj, m.first, tmp_str, zparams.flags, &retflags, nullptr);
int retval = zsetAdd(zobj, m.first, tmp_str, zparams.flags, &retflags, &new_score);
if (retval == 0) {
LOG(FATAL) << "unexpected error in zsetAdd: " << m.first;
if (zparams.flags & ZADD_IN_INCR) {
if (retval == 0) {
CHECK_EQ(1u, members.size());
add_result->is_nan = true;
return OpStatus::OK;
}
if (retflags & ZADD_OUT_NOP)
return OpStatus::SKIPPED;
}
if (retflags & ZADD_OUT_ADDED)
@ -721,8 +778,13 @@ OpResult<unsigned> ZSetFamily::OpAdd(const ZParams& zparams, const OpArgs& op_ar
DVLOG(2) << "ZAdd " << zobj->ptr;
res_it.value()->second.SyncRObj();
if (zparams.flags & ZADD_IN_INCR) {
add_result->new_score = new_score;
} else {
add_result->num_updated = zparams.ch ? added + updated : added;
}
return zparams.ch ? added + updated : added;
return OpStatus::OK;
}
OpResult<unsigned> ZSetFamily::OpRem(const OpArgs& op_args, string_view key, ArgSlice members) {
@ -763,7 +825,7 @@ OpResult<double> ZSetFamily::OpScore(const OpArgs& op_args, string_view key, str
return score;
}
auto ZSetFamily::OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args, std::string_view key)
auto ZSetFamily::OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args, string_view key)
-> OpResult<ScoredArray> {
OpResult<MainIterator> res_it = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_ZSET);
if (!res_it)

View File

@ -73,8 +73,15 @@ class ZSetFamily {
using ScoredMemberView = std::pair<double, std::string_view>;
using ScoredMemberSpan = absl::Span<ScoredMemberView>;
static OpResult<unsigned> OpAdd(const ZParams& zparams, const OpArgs& op_args,
std::string_view key, ScoredMemberSpan members);
struct AddResult {
double new_score;
unsigned num_updated = 0;
bool is_nan = false;
};
static facade::OpStatus OpAdd(const ZParams& zparams, const OpArgs& op_args, std::string_view key,
ScoredMemberSpan members, AddResult* add_result);
static OpResult<unsigned> OpRem(const OpArgs& op_args, std::string_view key, ArgSlice members);
static OpResult<double> OpScore(const OpArgs& op_args, std::string_view key,
std::string_view member);

View File

@ -79,4 +79,15 @@ TEST_F(ZSetFamilyTest, ZRemRangeScore) {
EXPECT_THAT(Run({"type", "x"}), ElementsAre("none"));
}
TEST_F(ZSetFamilyTest, IncrBy) {
auto resp = Run({"zadd", "key", "xx", "incr", "2.1", "member"});
EXPECT_THAT(resp[0], ArgType(RespExpr::NIL));
resp = Run({"zadd", "key", "nx", "incr", "2.1", "member"});
EXPECT_THAT(resp[0], "2.1");
resp = Run({"zadd", "key", "nx", "incr", "4.9", "member"});
EXPECT_THAT(resp[0], ArgType(RespExpr::NIL));
}
} // namespace dfly