Add more list commands

This commit is contained in:
Roman Gershman 2022-03-15 19:52:11 +02:00
parent f9b70125d6
commit 29c715fda5
10 changed files with 323 additions and 30 deletions

View File

@ -97,10 +97,10 @@ API 1.0
- [X] LLEN - [X] LLEN
- [X] LPOP - [X] LPOP
- [X] LPUSH - [X] LPUSH
- [ ] LRANGE - [X] LRANGE
- [ ] LREM - [X] LREM
- [ ] LSET - [X] LSET
- [ ] LTRIM - [X] LTRIM
- [X] RPOP - [X] RPOP
- [ ] RPOPLPUSH - [ ] RPOPLPUSH
- [X] RPUSH - [X] RPUSH

View File

@ -255,6 +255,16 @@ void RedisReplyBuilder::SendStringArr(absl::Span<const std::string_view> arr) {
SendDirect(res); SendDirect(res);
} }
void RedisReplyBuilder::SendStringArr(absl::Span<const string> arr) {
string res = absl::StrCat("*", arr.size(), kCRLF);
for (size_t i = 0; i < arr.size(); ++i) {
StrAppend(&res, "$", arr[i].size(), kCRLF);
res.append(arr[i]).append(kCRLF);
}
SendDirect(res);
}
void RedisReplyBuilder::StartArray(unsigned len) { void RedisReplyBuilder::StartArray(unsigned len) {
SendDirect(absl::StrCat("*", len, kCRLF)); SendDirect(absl::StrCat("*", len, kCRLF));
} }

View File

@ -130,6 +130,7 @@ class RedisReplyBuilder : public SinkReplyBuilder {
virtual void SendNullArray(); virtual void SendNullArray();
virtual void SendStringArr(absl::Span<const std::string_view> arr); virtual void SendStringArr(absl::Span<const std::string_view> arr);
virtual void SendStringArr(absl::Span<const std::string> arr);
virtual void SendNull(); virtual void SendNull();
virtual void SendDouble(double val); virtual void SendDouble(double val);

View File

@ -29,6 +29,7 @@ using facade::CmdArgVec;
using facade::ArgS; using facade::ArgS;
using ArgSlice = absl::Span<const std::string_view>; using ArgSlice = absl::Span<const std::string_view>;
using StringVec = std::vector<std::string>;
constexpr DbIndex kInvalidDbId = DbIndex(-1); constexpr DbIndex kInvalidDbId = DbIndex(-1);
constexpr ShardId kInvalidSid = ShardId(-1); constexpr ShardId kInvalidSid = ShardId(-1);

View File

@ -248,6 +248,92 @@ void ListFamily::LIndex(CmdArgList args, ConnectionContext* cntx) {
} }
} }
void ListFamily::LTrim(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
std::string_view s_str = ArgS(args, 2);
std::string_view e_str = ArgS(args, 3);
int32_t start, end;
if (!absl::SimpleAtoi(s_str, &start) || !absl::SimpleAtoi(e_str, &end)) {
(*cntx)->SendError(kInvalidIntErr);
return;
}
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpTrim(OpArgs{shard, t->db_index()}, key, start, end);
};
cntx->transaction->ScheduleSingleHop(std::move(cb));
(*cntx)->SendOk();
}
void ListFamily::LRange(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
std::string_view s_str = ArgS(args, 2);
std::string_view e_str = ArgS(args, 3);
int32_t start, end;
if (!absl::SimpleAtoi(s_str, &start) || !absl::SimpleAtoi(e_str, &end)) {
(*cntx)->SendError(kInvalidIntErr);
return;
}
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpRange(OpArgs{shard, t->db_index()}, key, start, end);
};
auto res = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (!res && res.status() != OpStatus::KEY_NOTFOUND) {
return (*cntx)->SendError(res.status());
}
(*cntx)->SendStringArr(*res);
}
// lrem key 5 foo, will remove foo elements from the list if exists at most 5 times.
void ListFamily::LRem(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
std::string_view index_str = ArgS(args, 2);
std::string_view elem = ArgS(args, 3);
int32_t count;
if (!absl::SimpleAtoi(index_str, &count)) {
(*cntx)->SendError(kInvalidIntErr);
return;
}
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpRem(OpArgs{shard, t->db_index()}, key, elem, count);
};
OpResult<uint32_t> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result) {
(*cntx)->SendLong(result.value());
} else {
(*cntx)->SendLong(0);
}
}
void ListFamily::LSet(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
std::string_view index_str = ArgS(args, 2);
std::string_view elem = ArgS(args, 3);
int32_t count;
if (!absl::SimpleAtoi(index_str, &count)) {
(*cntx)->SendError(kInvalidIntErr);
return;
}
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpSet(OpArgs{shard, t->db_index()}, key, elem, count);
};
OpResult<void> result = cntx->transaction->ScheduleSingleHop(std::move(cb));
if (result) {
(*cntx)->SendOk();
} else {
(*cntx)->SendError(result.status());
}
}
void ListFamily::BLPop(CmdArgList args, ConnectionContext* cntx) { void ListFamily::BLPop(CmdArgList args, ConnectionContext* cntx) {
DCHECK_GE(args.size(), 3u); DCHECK_GE(args.size(), 3u);
@ -401,16 +487,156 @@ OpResult<string> ListFamily::OpIndex(const OpArgs& op_args, std::string_view key
return res.status(); return res.status();
quicklist* ql = GetQL(res.value()->second); quicklist* ql = GetQL(res.value()->second);
quicklistEntry entry = QLEntry(); quicklistEntry entry = QLEntry();
quicklistIter* iter = quicklistGetIteratorEntryAtIdx(ql, index, &entry); quicklistIter* iter = quicklistGetIteratorAtIdx(ql, AL_START_TAIL, index);
if (!iter) if (!iter)
return OpStatus::KEY_NOTFOUND; return OpStatus::KEY_NOTFOUND;
quicklistNext(iter, &entry);
string str;
if (entry.value) { if (entry.value) {
return string{reinterpret_cast<char*>(entry.value), entry.sz}; str.assign(reinterpret_cast<char*>(entry.value), entry.sz);
} else { } else {
return absl::StrCat(entry.longval); str = absl::StrCat(entry.longval);
} }
quicklistReleaseIterator(iter);
return str;
}
OpResult<uint32_t> ListFamily::OpRem(const OpArgs& op_args, std::string_view key,
std::string_view elem, long count) {
DCHECK(!elem.empty());
auto res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_LIST);
if (!res)
return res.status();
quicklist* ql = res.value()->second.GetQL();
int iter_direction = AL_START_HEAD;
long long index = 0;
if (count < 0) {
count = -count;
iter_direction = AL_START_TAIL;
index = -1;
}
quicklistIter* qiter = quicklistGetIteratorAtIdx(ql, iter_direction, index);
quicklistEntry entry;
unsigned removed = 0;
const uint8_t* elem_ptr = reinterpret_cast<const uint8_t*>(elem.data());
while (quicklistNext(qiter, &entry)) {
if (quicklistCompare(&entry, elem_ptr, elem.size())) {
quicklistDelEntry(qiter, &entry);
removed++;
if (count && removed == count)
break;
}
}
quicklistReleaseIterator(qiter);
if (quicklistCount(ql) == 0) {
CHECK(op_args.shard->db_slice().Del(op_args.db_ind, res.value()));
}
return removed;
}
OpStatus ListFamily::OpSet(const OpArgs& op_args, std::string_view key, std::string_view elem,
long index) {
DCHECK(!elem.empty());
auto res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_LIST);
if (!res)
return res.status();
quicklist* ql = res.value()->second.GetQL();
int replaced = quicklistReplaceAtIndex(ql, index, elem.data(), elem.size());
if (!replaced) {
return OpStatus::OUT_OF_RANGE;
}
return OpStatus::OK;
}
OpStatus ListFamily::OpTrim(const OpArgs& op_args, std::string_view key, long start, long end) {
auto res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_LIST);
if (!res)
return res.status();
quicklist* ql = res.value()->second.GetQL();
long llen = quicklistCount(ql);
/* convert negative indexes */
if (start < 0)
start = llen + start;
if (end < 0)
end = llen + end;
if (start < 0)
start = 0;
long ltrim, rtrim;
/* Invariant: start >= 0, so this test will be true when end < 0.
* The range is empty when start > end or start >= length. */
if (start > end || start >= llen) {
/* Out of range start or start > end result in empty list */
ltrim = llen;
rtrim = 0;
} else {
if (end >= llen)
end = llen - 1;
ltrim = start;
rtrim = llen - end - 1;
}
quicklistDelRange(ql, 0, ltrim);
quicklistDelRange(ql, -rtrim, rtrim);
if (quicklistCount(ql) == 0) {
CHECK(op_args.shard->db_slice().Del(op_args.db_ind, res.value()));
}
return OpStatus::OK;
}
OpResult<StringVec> ListFamily::OpRange(const OpArgs& op_args, std::string_view key, long start,
long end) {
auto res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_LIST);
if (!res)
return res.status();
quicklist* ql = res.value()->second.GetQL();
long llen = quicklistCount(ql);
/* convert negative indexes */
if (start < 0)
start = llen + start;
if (end < 0)
end = llen + end;
if (start < 0)
start = 0;
/* Invariant: start >= 0, so this test will be true when end < 0.
* The range is empty when start > end or start >= length. */
if (start > end || start >= llen) {
/* Out of range start or start > end result in empty list */
return StringVec{};
}
if (end >= llen)
end = llen - 1;
unsigned lrange = end - start + 1;
quicklistIter* qiter = quicklistGetIteratorAtIdx(ql, AL_START_HEAD, start);
quicklistEntry entry = QLEntry();
StringVec str_vec;
unsigned cnt = 0;
while (cnt < lrange && quicklistNext(qiter, &entry)) {
if (entry.value)
str_vec.emplace_back(reinterpret_cast<char*>(entry.value), entry.sz);
else
str_vec.push_back(absl::StrCat(entry.longval));
++cnt;
}
quicklistReleaseIterator(qiter);
return str_vec;
} }
using CI = CommandId; using CI = CommandId;
@ -424,7 +650,11 @@ void ListFamily::Register(CommandRegistry* registry) {
<< CI{"RPOP", CO::WRITE | CO::FAST | CO::DENYOOM, 2, 1, 1, 1}.HFUNC(RPop) << CI{"RPOP", CO::WRITE | CO::FAST | CO::DENYOOM, 2, 1, 1, 1}.HFUNC(RPop)
<< CI{"BLPOP", CO::WRITE | CO::NOSCRIPT | CO::BLOCKING, -3, 1, -2, 1}.HFUNC(BLPop) << CI{"BLPOP", CO::WRITE | CO::NOSCRIPT | CO::BLOCKING, -3, 1, -2, 1}.HFUNC(BLPop)
<< CI{"LLEN", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(LLen) << CI{"LLEN", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(LLen)
<< CI{"LINDEX", CO::READONLY, 3, 1, 1, 1}.HFUNC(LIndex); << CI{"LINDEX", CO::READONLY, 3, 1, 1, 1}.HFUNC(LIndex)
<< CI{"LRANGE", CO::READONLY, 4, 1, 1, 1}.HFUNC(LRange)
<< CI{"LSET", CO::WRITE | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(LSet)
<< CI{"LTRIM", CO::WRITE, 4, 1, 1, 1}.HFUNC(LTrim)
<< CI{"LREM", CO::WRITE, 4, 1, 1, 1}.HFUNC(LRem);
} }
} // namespace dfly } // namespace dfly

View File

@ -27,6 +27,10 @@ class ListFamily {
static void BLPop(CmdArgList args, ConnectionContext* cntx); static void BLPop(CmdArgList args, ConnectionContext* cntx);
static void LLen(CmdArgList args, ConnectionContext* cntx); static void LLen(CmdArgList args, ConnectionContext* cntx);
static void LIndex(CmdArgList args, ConnectionContext* cntx); static void LIndex(CmdArgList args, ConnectionContext* cntx);
static void LTrim(CmdArgList args, ConnectionContext* cntx);
static void LRange(CmdArgList args, ConnectionContext* cntx);
static void LRem(CmdArgList args, ConnectionContext* cntx);
static void LSet(CmdArgList args, ConnectionContext* cntx);
static void PopGeneric(ListDir dir, const CmdArgList& args, ConnectionContext* cntx); static void PopGeneric(ListDir dir, const CmdArgList& args, ConnectionContext* cntx);
static void PushGeneric(ListDir dir, const CmdArgList& args, ConnectionContext* cntx); static void PushGeneric(ListDir dir, const CmdArgList& args, ConnectionContext* cntx);
@ -36,6 +40,15 @@ class ListFamily {
static OpResult<std::string> OpPop(const OpArgs& op_args, std::string_view key, ListDir dir); static OpResult<std::string> OpPop(const OpArgs& op_args, std::string_view key, ListDir dir);
static OpResult<uint32_t> OpLen(const OpArgs& op_args, std::string_view key); static OpResult<uint32_t> OpLen(const OpArgs& op_args, std::string_view key);
static OpResult<std::string> OpIndex(const OpArgs& op_args, std::string_view key, long index); static OpResult<std::string> OpIndex(const OpArgs& op_args, std::string_view key, long index);
static OpResult<uint32_t> OpRem(const OpArgs& op_args, std::string_view key,
std::string_view elem, long count);
static facade::OpStatus OpSet(const OpArgs& op_args, std::string_view key, std::string_view elem,
long count);
static facade::OpStatus OpTrim(const OpArgs& op_args, std::string_view key, long start, long end);
static OpResult<StringVec> OpRange(const OpArgs& op_args, std::string_view key,
long start, long end);
}; };
} // namespace dfly } // namespace dfly

View File

@ -154,6 +154,38 @@ TEST_F(ListFamilyTest, BLPopTimeout) {
ASSERT_FALSE(service_->IsLocked(0, kKey1)); ASSERT_FALSE(service_->IsLocked(0, kKey1));
} }
TEST_F(ListFamilyTest, LRem) {
auto resp = Run({"rpush", kKey1, "a", "b", "a", "c"});
ASSERT_THAT(resp, ElementsAre(IntArg(4)));
resp = Run({"lrem", kKey1, "2", "a"});
ASSERT_THAT(resp, ElementsAre(IntArg(2)));
ASSERT_THAT(Run({"lrange", kKey1, "0", "1"}), ElementsAre("b", "c"));
}
TEST_F(ListFamilyTest, LTrim) {
Run({"rpush", kKey1, "a", "b", "c", "d"});
ASSERT_THAT(Run({"ltrim", kKey1, "-2", "-1"}), RespEq("OK"));
ASSERT_THAT(Run({"lrange", kKey1, "0", "1"}), ElementsAre("c", "d"));
ASSERT_THAT(Run({"ltrim", kKey1, "0", "0"}), RespEq("OK"));
ASSERT_THAT(Run({"lrange", kKey1, "0", "1"}), ElementsAre("c"));
}
TEST_F(ListFamilyTest, LRange) {
auto resp = Run({"lrange", kKey1, "0", "5"});
ASSERT_THAT(resp[0], ArrLen(0));
Run({"rpush", kKey1, "0", "1", "2"});
resp = Run({"lrange", kKey1, "-2", "-1"});
ASSERT_THAT(resp, ElementsAre("1", "2"));
}
TEST_F(ListFamilyTest, Lset) {
Run({"rpush", kKey1, "0", "1", "2"});
ASSERT_THAT(Run({"lset", kKey1, "0", "bar"}), RespEq("OK"));
ASSERT_THAT(Run({"lpop", kKey1}), RespEq("bar"));
ASSERT_THAT(Run({"lset", kKey1, "-1", "foo"}), RespEq("OK"));
ASSERT_THAT(Run({"rpop", kKey1}), RespEq("foo"));
}
TEST_F(ListFamilyTest, BLPopSerialize) { TEST_F(ListFamilyTest, BLPopSerialize) {
RespVec blpop_resp; RespVec blpop_resp;

View File

@ -76,6 +76,7 @@ class InterpreterReplier : public RedisReplyBuilder {
void SendNullArray() final; void SendNullArray() final;
void SendStringArr(absl::Span<const string_view> arr) final; void SendStringArr(absl::Span<const string_view> arr) final;
void SendStringArr(absl::Span<const string> arr) final;
void SendNull() final; void SendNull() final;
void SendLong(long val) final; void SendLong(long val) final;
@ -211,6 +212,15 @@ void InterpreterReplier::SendStringArr(absl::Span<const string_view> arr) {
PostItem(); PostItem();
} }
void InterpreterReplier::SendStringArr(absl::Span<const string> arr) {
explr_->OnArrayStart(arr.size());
for (uint32_t i = 0; i < arr.size(); ++i) {
explr_->OnString(arr[i]);
}
explr_->OnArrayEnd();
PostItem();
}
void InterpreterReplier::SendNull() { void InterpreterReplier::SendNull() {
explr_->OnNil(); explr_->OnNil();
PostItem(); PostItem();

View File

@ -576,7 +576,7 @@ void SetFamily::SPop(CmdArgList args, ConnectionContext* cntx) {
return OpPop(OpArgs{shard, t->db_index()}, key, count); return OpPop(OpArgs{shard, t->db_index()}, key, count);
}; };
OpResult<StringSet> result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); OpResult<StringVec> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result || result.status() == OpStatus::KEY_NOTFOUND) { if (result || result.status() == OpStatus::KEY_NOTFOUND) {
if (args.size() == 2) { // SPOP key if (args.size() == 2) { // SPOP key
if (result.status() == OpStatus::KEY_NOTFOUND) { if (result.status() == OpStatus::KEY_NOTFOUND) {
@ -586,8 +586,7 @@ void SetFamily::SPop(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendBulkString(result.value().front()); (*cntx)->SendBulkString(result.value().front());
} }
} else { // SPOP key cnt } else { // SPOP key cnt
SvArray vec{result->begin(), result->end()}; (*cntx)->SendStringArr(*result);
(*cntx)->SendStringArr(vec);
} }
return; return;
} }
@ -678,14 +677,15 @@ void SetFamily::SDiffStore(CmdArgList args, ConnectionContext* cntx) {
void SetFamily::SMembers(CmdArgList args, ConnectionContext* cntx) { void SetFamily::SMembers(CmdArgList args, ConnectionContext* cntx) {
auto cb = [](Transaction* t, EngineShard* shard) { return OpInter(t, shard, false); }; auto cb = [](Transaction* t, EngineShard* shard) { return OpInter(t, shard, false); };
OpResult<StringSet> result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); OpResult<StringVec> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result || result.status() == OpStatus::KEY_NOTFOUND) { if (result || result.status() == OpStatus::KEY_NOTFOUND) {
SvArray arr{result->begin(), result->end()}; StringVec& svec = result.value();
if (cntx->conn_state.script_info) { // sort under script if (cntx->conn_state.script_info) { // sort under script
sort(arr.begin(), arr.end()); sort(svec.begin(), svec.end());
} }
(*cntx)->SendStringArr(arr); (*cntx)->SendStringArr(*result);
} else { } else {
(*cntx)->SendError(result.status()); (*cntx)->SendError(result.status());
} }
@ -817,7 +817,7 @@ void SetFamily::SUnionStore(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendLong(result.size()); (*cntx)->SendLong(result.size());
} }
auto SetFamily::OpUnion(const OpArgs& op_args, const ArgSlice& keys) -> OpResult<StringSet> { OpResult<StringVec> SetFamily::OpUnion(const OpArgs& op_args, const ArgSlice& keys) {
DCHECK(!keys.empty()); DCHECK(!keys.empty());
absl::flat_hash_set<string> uniques; absl::flat_hash_set<string> uniques;
@ -836,7 +836,7 @@ auto SetFamily::OpUnion(const OpArgs& op_args, const ArgSlice& keys) -> OpResult
return ToVec(std::move(uniques)); return ToVec(std::move(uniques));
} }
auto SetFamily::OpDiff(const Transaction* t, EngineShard* es) -> OpResult<StringSet> { OpResult<StringVec> SetFamily::OpDiff(const Transaction* t, EngineShard* es) {
ArgSlice keys = t->ShardArgsInShard(es->shard_id()); ArgSlice keys = t->ShardArgsInShard(es->shard_id());
DCHECK(!keys.empty()); DCHECK(!keys.empty());
@ -882,14 +882,13 @@ auto SetFamily::OpDiff(const Transaction* t, EngineShard* es) -> OpResult<String
return ToVec(std::move(uniques)); return ToVec(std::move(uniques));
} }
auto SetFamily::OpPop(const OpArgs& op_args, std::string_view key, unsigned count) OpResult<StringVec> SetFamily::OpPop(const OpArgs& op_args, std::string_view key, unsigned count) {
-> OpResult<StringSet> {
auto* es = op_args.shard; auto* es = op_args.shard;
OpResult<MainIterator> find_res = es->db_slice().Find(op_args.db_ind, key, OBJ_SET); OpResult<MainIterator> find_res = es->db_slice().Find(op_args.db_ind, key, OBJ_SET);
if (!find_res) if (!find_res)
return find_res.status(); return find_res.status();
StringSet result; StringVec result;
if (count == 0) if (count == 0)
return result; return result;
@ -935,15 +934,14 @@ auto SetFamily::OpPop(const OpArgs& op_args, std::string_view key, unsigned coun
return result; return result;
} }
auto SetFamily::OpInter(const Transaction* t, EngineShard* es, bool remove_first) OpResult<StringVec> SetFamily::OpInter(const Transaction* t, EngineShard* es, bool remove_first) {
-> OpResult<StringSet> {
ArgSlice keys = t->ShardArgsInShard(es->shard_id()); ArgSlice keys = t->ShardArgsInShard(es->shard_id());
if (remove_first) { if (remove_first) {
keys.remove_prefix(1); keys.remove_prefix(1);
} }
DCHECK(!keys.empty()); DCHECK(!keys.empty());
StringSet result; StringVec result;
if (keys.size() == 1) { if (keys.size() == 1) {
OpResult<MainIterator> find_res = es->db_slice().Find(t->db_index(), keys.front(), OBJ_SET); OpResult<MainIterator> find_res = es->db_slice().Find(t->db_index(), keys.front(), OBJ_SET);
if (!find_res) if (!find_res)

View File

@ -35,14 +35,12 @@ class SetFamily {
static void SInterStore(CmdArgList args, ConnectionContext* cntx); static void SInterStore(CmdArgList args, ConnectionContext* cntx);
using StringSet = std::vector<std::string>; static OpResult<StringVec> OpUnion(const OpArgs& op_args, const ArgSlice& args);
static OpResult<StringVec> OpDiff(const Transaction* t, EngineShard* es);
static OpResult<StringSet> OpUnion(const OpArgs& op_args, const ArgSlice& args); static OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_first);
static OpResult<StringSet> OpDiff(const Transaction* t, EngineShard* es);
static OpResult<StringSet> OpInter(const Transaction* t, EngineShard* es, bool remove_first);
// count - how many elements to pop. // count - how many elements to pop.
static OpResult<StringSet> OpPop(const OpArgs& op_args, std::string_view key, unsigned count); static OpResult<StringVec> OpPop(const OpArgs& op_args, std::string_view key, unsigned count);
}; };