diff --git a/README.md b/README.md index 05d0171..d5761bd 100644 --- a/README.md +++ b/README.md @@ -97,10 +97,10 @@ API 1.0 - [X] LLEN - [X] LPOP - [X] LPUSH - - [ ] LRANGE - - [ ] LREM - - [ ] LSET - - [ ] LTRIM + - [X] LRANGE + - [X] LREM + - [X] LSET + - [X] LTRIM - [X] RPOP - [ ] RPOPLPUSH - [X] RPUSH diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc index b796426..f79e88e 100644 --- a/src/facade/reply_builder.cc +++ b/src/facade/reply_builder.cc @@ -255,6 +255,16 @@ void RedisReplyBuilder::SendStringArr(absl::Span arr) { SendDirect(res); } +void RedisReplyBuilder::SendStringArr(absl::Span arr) { + string res = absl::StrCat("*", arr.size(), kCRLF); + + for (size_t i = 0; i < arr.size(); ++i) { + StrAppend(&res, "$", arr[i].size(), kCRLF); + res.append(arr[i]).append(kCRLF); + } + SendDirect(res); +} + void RedisReplyBuilder::StartArray(unsigned len) { SendDirect(absl::StrCat("*", len, kCRLF)); } diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index 1bf2856..57370fc 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -130,6 +130,7 @@ class RedisReplyBuilder : public SinkReplyBuilder { virtual void SendNullArray(); virtual void SendStringArr(absl::Span arr); + virtual void SendStringArr(absl::Span arr); virtual void SendNull(); virtual void SendDouble(double val); diff --git a/src/server/common_types.h b/src/server/common_types.h index 79164c4..5b27259 100644 --- a/src/server/common_types.h +++ b/src/server/common_types.h @@ -29,6 +29,7 @@ using facade::CmdArgVec; using facade::ArgS; using ArgSlice = absl::Span; +using StringVec = std::vector; constexpr DbIndex kInvalidDbId = DbIndex(-1); constexpr ShardId kInvalidSid = ShardId(-1); diff --git a/src/server/list_family.cc b/src/server/list_family.cc index 045333f..5bb9481 100644 --- a/src/server/list_family.cc +++ b/src/server/list_family.cc @@ -248,6 +248,92 @@ void ListFamily::LIndex(CmdArgList args, ConnectionContext* cntx) { } } +void ListFamily::LTrim(CmdArgList args, ConnectionContext* cntx) { + std::string_view key = ArgS(args, 1); + std::string_view s_str = ArgS(args, 2); + std::string_view e_str = ArgS(args, 3); + int32_t start, end; + + if (!absl::SimpleAtoi(s_str, &start) || !absl::SimpleAtoi(e_str, &end)) { + (*cntx)->SendError(kInvalidIntErr); + return; + } + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpTrim(OpArgs{shard, t->db_index()}, key, start, end); + }; + cntx->transaction->ScheduleSingleHop(std::move(cb)); + (*cntx)->SendOk(); +} + +void ListFamily::LRange(CmdArgList args, ConnectionContext* cntx) { + std::string_view key = ArgS(args, 1); + std::string_view s_str = ArgS(args, 2); + std::string_view e_str = ArgS(args, 3); + int32_t start, end; + + if (!absl::SimpleAtoi(s_str, &start) || !absl::SimpleAtoi(e_str, &end)) { + (*cntx)->SendError(kInvalidIntErr); + return; + } + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpRange(OpArgs{shard, t->db_index()}, key, start, end); + }; + + auto res = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + if (!res && res.status() != OpStatus::KEY_NOTFOUND) { + return (*cntx)->SendError(res.status()); + } + + (*cntx)->SendStringArr(*res); +} + +// lrem key 5 foo, will remove foo elements from the list if exists at most 5 times. +void ListFamily::LRem(CmdArgList args, ConnectionContext* cntx) { + std::string_view key = ArgS(args, 1); + std::string_view index_str = ArgS(args, 2); + std::string_view elem = ArgS(args, 3); + int32_t count; + + if (!absl::SimpleAtoi(index_str, &count)) { + (*cntx)->SendError(kInvalidIntErr); + return; + } + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpRem(OpArgs{shard, t->db_index()}, key, elem, count); + }; + OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + if (result) { + (*cntx)->SendLong(result.value()); + } else { + (*cntx)->SendLong(0); + } +} + +void ListFamily::LSet(CmdArgList args, ConnectionContext* cntx) { + std::string_view key = ArgS(args, 1); + std::string_view index_str = ArgS(args, 2); + std::string_view elem = ArgS(args, 3); + int32_t count; + + if (!absl::SimpleAtoi(index_str, &count)) { + (*cntx)->SendError(kInvalidIntErr); + return; + } + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpSet(OpArgs{shard, t->db_index()}, key, elem, count); + }; + OpResult result = cntx->transaction->ScheduleSingleHop(std::move(cb)); + if (result) { + (*cntx)->SendOk(); + } else { + (*cntx)->SendError(result.status()); + } +} + void ListFamily::BLPop(CmdArgList args, ConnectionContext* cntx) { DCHECK_GE(args.size(), 3u); @@ -401,16 +487,156 @@ OpResult ListFamily::OpIndex(const OpArgs& op_args, std::string_view key return res.status(); quicklist* ql = GetQL(res.value()->second); quicklistEntry entry = QLEntry(); - quicklistIter* iter = quicklistGetIteratorEntryAtIdx(ql, index, &entry); - + quicklistIter* iter = quicklistGetIteratorAtIdx(ql, AL_START_TAIL, index); if (!iter) return OpStatus::KEY_NOTFOUND; + quicklistNext(iter, &entry); + string str; + if (entry.value) { - return string{reinterpret_cast(entry.value), entry.sz}; + str.assign(reinterpret_cast(entry.value), entry.sz); } else { - return absl::StrCat(entry.longval); + str = absl::StrCat(entry.longval); } + quicklistReleaseIterator(iter); + + return str; +} + +OpResult ListFamily::OpRem(const OpArgs& op_args, std::string_view key, + std::string_view elem, long count) { + DCHECK(!elem.empty()); + auto res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_LIST); + if (!res) + return res.status(); + quicklist* ql = res.value()->second.GetQL(); + + int iter_direction = AL_START_HEAD; + long long index = 0; + if (count < 0) { + count = -count; + iter_direction = AL_START_TAIL; + index = -1; + } + + quicklistIter* qiter = quicklistGetIteratorAtIdx(ql, iter_direction, index); + quicklistEntry entry; + unsigned removed = 0; + const uint8_t* elem_ptr = reinterpret_cast(elem.data()); + while (quicklistNext(qiter, &entry)) { + if (quicklistCompare(&entry, elem_ptr, elem.size())) { + quicklistDelEntry(qiter, &entry); + removed++; + if (count && removed == count) + break; + } + } + quicklistReleaseIterator(qiter); + + if (quicklistCount(ql) == 0) { + CHECK(op_args.shard->db_slice().Del(op_args.db_ind, res.value())); + } + + return removed; +} + +OpStatus ListFamily::OpSet(const OpArgs& op_args, std::string_view key, std::string_view elem, + long index) { + DCHECK(!elem.empty()); + auto res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_LIST); + if (!res) + return res.status(); + quicklist* ql = res.value()->second.GetQL(); + + int replaced = quicklistReplaceAtIndex(ql, index, elem.data(), elem.size()); + if (!replaced) { + return OpStatus::OUT_OF_RANGE; + } + return OpStatus::OK; +} + +OpStatus ListFamily::OpTrim(const OpArgs& op_args, std::string_view key, long start, long end) { + auto res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_LIST); + if (!res) + return res.status(); + quicklist* ql = res.value()->second.GetQL(); + long llen = quicklistCount(ql); + + /* convert negative indexes */ + if (start < 0) + start = llen + start; + if (end < 0) + end = llen + end; + if (start < 0) + start = 0; + + long ltrim, rtrim; + + /* Invariant: start >= 0, so this test will be true when end < 0. + * The range is empty when start > end or start >= length. */ + if (start > end || start >= llen) { + /* Out of range start or start > end result in empty list */ + ltrim = llen; + rtrim = 0; + } else { + if (end >= llen) + end = llen - 1; + ltrim = start; + rtrim = llen - end - 1; + } + quicklistDelRange(ql, 0, ltrim); + quicklistDelRange(ql, -rtrim, rtrim); + + if (quicklistCount(ql) == 0) { + CHECK(op_args.shard->db_slice().Del(op_args.db_ind, res.value())); + } + return OpStatus::OK; +} + +OpResult ListFamily::OpRange(const OpArgs& op_args, std::string_view key, long start, + long end) { + auto res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_LIST); + if (!res) + return res.status(); + + quicklist* ql = res.value()->second.GetQL(); + long llen = quicklistCount(ql); + + /* convert negative indexes */ + if (start < 0) + start = llen + start; + if (end < 0) + end = llen + end; + if (start < 0) + start = 0; + + /* Invariant: start >= 0, so this test will be true when end < 0. + * The range is empty when start > end or start >= length. */ + if (start > end || start >= llen) { + /* Out of range start or start > end result in empty list */ + return StringVec{}; + } + + if (end >= llen) + end = llen - 1; + + unsigned lrange = end - start + 1; + quicklistIter* qiter = quicklistGetIteratorAtIdx(ql, AL_START_HEAD, start); + quicklistEntry entry = QLEntry(); + StringVec str_vec; + + unsigned cnt = 0; + while (cnt < lrange && quicklistNext(qiter, &entry)) { + if (entry.value) + str_vec.emplace_back(reinterpret_cast(entry.value), entry.sz); + else + str_vec.push_back(absl::StrCat(entry.longval)); + ++cnt; + } + quicklistReleaseIterator(qiter); + + return str_vec; } using CI = CommandId; @@ -424,7 +650,11 @@ void ListFamily::Register(CommandRegistry* registry) { << CI{"RPOP", CO::WRITE | CO::FAST | CO::DENYOOM, 2, 1, 1, 1}.HFUNC(RPop) << CI{"BLPOP", CO::WRITE | CO::NOSCRIPT | CO::BLOCKING, -3, 1, -2, 1}.HFUNC(BLPop) << CI{"LLEN", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(LLen) - << CI{"LINDEX", CO::READONLY, 3, 1, 1, 1}.HFUNC(LIndex); + << CI{"LINDEX", CO::READONLY, 3, 1, 1, 1}.HFUNC(LIndex) + << CI{"LRANGE", CO::READONLY, 4, 1, 1, 1}.HFUNC(LRange) + << CI{"LSET", CO::WRITE | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(LSet) + << CI{"LTRIM", CO::WRITE, 4, 1, 1, 1}.HFUNC(LTrim) + << CI{"LREM", CO::WRITE, 4, 1, 1, 1}.HFUNC(LRem); } } // namespace dfly diff --git a/src/server/list_family.h b/src/server/list_family.h index a9ab7d3..b74bad6 100644 --- a/src/server/list_family.h +++ b/src/server/list_family.h @@ -27,6 +27,10 @@ class ListFamily { static void BLPop(CmdArgList args, ConnectionContext* cntx); static void LLen(CmdArgList args, ConnectionContext* cntx); static void LIndex(CmdArgList args, ConnectionContext* cntx); + static void LTrim(CmdArgList args, ConnectionContext* cntx); + static void LRange(CmdArgList args, ConnectionContext* cntx); + static void LRem(CmdArgList args, ConnectionContext* cntx); + static void LSet(CmdArgList args, ConnectionContext* cntx); static void PopGeneric(ListDir dir, const CmdArgList& args, ConnectionContext* cntx); static void PushGeneric(ListDir dir, const CmdArgList& args, ConnectionContext* cntx); @@ -36,6 +40,15 @@ class ListFamily { static OpResult OpPop(const OpArgs& op_args, std::string_view key, ListDir dir); static OpResult OpLen(const OpArgs& op_args, std::string_view key); static OpResult OpIndex(const OpArgs& op_args, std::string_view key, long index); + + static OpResult OpRem(const OpArgs& op_args, std::string_view key, + std::string_view elem, long count); + static facade::OpStatus OpSet(const OpArgs& op_args, std::string_view key, std::string_view elem, + long count); + static facade::OpStatus OpTrim(const OpArgs& op_args, std::string_view key, long start, long end); + + static OpResult OpRange(const OpArgs& op_args, std::string_view key, + long start, long end); }; } // namespace dfly diff --git a/src/server/list_family_test.cc b/src/server/list_family_test.cc index e236148..ba1c83a 100644 --- a/src/server/list_family_test.cc +++ b/src/server/list_family_test.cc @@ -154,6 +154,38 @@ TEST_F(ListFamilyTest, BLPopTimeout) { ASSERT_FALSE(service_->IsLocked(0, kKey1)); } +TEST_F(ListFamilyTest, LRem) { + auto resp = Run({"rpush", kKey1, "a", "b", "a", "c"}); + ASSERT_THAT(resp, ElementsAre(IntArg(4))); + resp = Run({"lrem", kKey1, "2", "a"}); + ASSERT_THAT(resp, ElementsAre(IntArg(2))); + ASSERT_THAT(Run({"lrange", kKey1, "0", "1"}), ElementsAre("b", "c")); +} + +TEST_F(ListFamilyTest, LTrim) { + Run({"rpush", kKey1, "a", "b", "c", "d"}); + ASSERT_THAT(Run({"ltrim", kKey1, "-2", "-1"}), RespEq("OK")); + ASSERT_THAT(Run({"lrange", kKey1, "0", "1"}), ElementsAre("c", "d")); + ASSERT_THAT(Run({"ltrim", kKey1, "0", "0"}), RespEq("OK")); + ASSERT_THAT(Run({"lrange", kKey1, "0", "1"}), ElementsAre("c")); +} + +TEST_F(ListFamilyTest, LRange) { + auto resp = Run({"lrange", kKey1, "0", "5"}); + ASSERT_THAT(resp[0], ArrLen(0)); + Run({"rpush", kKey1, "0", "1", "2"}); + resp = Run({"lrange", kKey1, "-2", "-1"}); + ASSERT_THAT(resp, ElementsAre("1", "2")); +} + +TEST_F(ListFamilyTest, Lset) { + Run({"rpush", kKey1, "0", "1", "2"}); + ASSERT_THAT(Run({"lset", kKey1, "0", "bar"}), RespEq("OK")); + ASSERT_THAT(Run({"lpop", kKey1}), RespEq("bar")); + ASSERT_THAT(Run({"lset", kKey1, "-1", "foo"}), RespEq("OK")); + ASSERT_THAT(Run({"rpop", kKey1}), RespEq("foo")); +} + TEST_F(ListFamilyTest, BLPopSerialize) { RespVec blpop_resp; diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 3f14556..e3c69f6 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -76,6 +76,7 @@ class InterpreterReplier : public RedisReplyBuilder { void SendNullArray() final; void SendStringArr(absl::Span arr) final; + void SendStringArr(absl::Span arr) final; void SendNull() final; void SendLong(long val) final; @@ -211,6 +212,15 @@ void InterpreterReplier::SendStringArr(absl::Span arr) { PostItem(); } +void InterpreterReplier::SendStringArr(absl::Span arr) { + explr_->OnArrayStart(arr.size()); + for (uint32_t i = 0; i < arr.size(); ++i) { + explr_->OnString(arr[i]); + } + explr_->OnArrayEnd(); + PostItem(); +} + void InterpreterReplier::SendNull() { explr_->OnNil(); PostItem(); diff --git a/src/server/set_family.cc b/src/server/set_family.cc index 86265af..de5a1d8 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -576,7 +576,7 @@ void SetFamily::SPop(CmdArgList args, ConnectionContext* cntx) { return OpPop(OpArgs{shard, t->db_index()}, key, count); }; - OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); if (result || result.status() == OpStatus::KEY_NOTFOUND) { if (args.size() == 2) { // SPOP key if (result.status() == OpStatus::KEY_NOTFOUND) { @@ -586,8 +586,7 @@ void SetFamily::SPop(CmdArgList args, ConnectionContext* cntx) { (*cntx)->SendBulkString(result.value().front()); } } else { // SPOP key cnt - SvArray vec{result->begin(), result->end()}; - (*cntx)->SendStringArr(vec); + (*cntx)->SendStringArr(*result); } return; } @@ -678,14 +677,15 @@ void SetFamily::SDiffStore(CmdArgList args, ConnectionContext* cntx) { void SetFamily::SMembers(CmdArgList args, ConnectionContext* cntx) { auto cb = [](Transaction* t, EngineShard* shard) { return OpInter(t, shard, false); }; - OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); if (result || result.status() == OpStatus::KEY_NOTFOUND) { - SvArray arr{result->begin(), result->end()}; + StringVec& svec = result.value(); + if (cntx->conn_state.script_info) { // sort under script - sort(arr.begin(), arr.end()); + sort(svec.begin(), svec.end()); } - (*cntx)->SendStringArr(arr); + (*cntx)->SendStringArr(*result); } else { (*cntx)->SendError(result.status()); } @@ -817,7 +817,7 @@ void SetFamily::SUnionStore(CmdArgList args, ConnectionContext* cntx) { (*cntx)->SendLong(result.size()); } -auto SetFamily::OpUnion(const OpArgs& op_args, const ArgSlice& keys) -> OpResult { +OpResult SetFamily::OpUnion(const OpArgs& op_args, const ArgSlice& keys) { DCHECK(!keys.empty()); absl::flat_hash_set uniques; @@ -836,7 +836,7 @@ auto SetFamily::OpUnion(const OpArgs& op_args, const ArgSlice& keys) -> OpResult return ToVec(std::move(uniques)); } -auto SetFamily::OpDiff(const Transaction* t, EngineShard* es) -> OpResult { +OpResult SetFamily::OpDiff(const Transaction* t, EngineShard* es) { ArgSlice keys = t->ShardArgsInShard(es->shard_id()); DCHECK(!keys.empty()); @@ -882,14 +882,13 @@ auto SetFamily::OpDiff(const Transaction* t, EngineShard* es) -> OpResult OpResult { +OpResult SetFamily::OpPop(const OpArgs& op_args, std::string_view key, unsigned count) { auto* es = op_args.shard; OpResult find_res = es->db_slice().Find(op_args.db_ind, key, OBJ_SET); if (!find_res) return find_res.status(); - StringSet result; + StringVec result; if (count == 0) return result; @@ -935,15 +934,14 @@ auto SetFamily::OpPop(const OpArgs& op_args, std::string_view key, unsigned coun return result; } -auto SetFamily::OpInter(const Transaction* t, EngineShard* es, bool remove_first) - -> OpResult { +OpResult SetFamily::OpInter(const Transaction* t, EngineShard* es, bool remove_first) { ArgSlice keys = t->ShardArgsInShard(es->shard_id()); if (remove_first) { keys.remove_prefix(1); } DCHECK(!keys.empty()); - StringSet result; + StringVec result; if (keys.size() == 1) { OpResult find_res = es->db_slice().Find(t->db_index(), keys.front(), OBJ_SET); if (!find_res) diff --git a/src/server/set_family.h b/src/server/set_family.h index 1523c77..598c379 100644 --- a/src/server/set_family.h +++ b/src/server/set_family.h @@ -35,14 +35,12 @@ class SetFamily { static void SInterStore(CmdArgList args, ConnectionContext* cntx); - using StringSet = std::vector; - - static OpResult OpUnion(const OpArgs& op_args, const ArgSlice& args); - static OpResult OpDiff(const Transaction* t, EngineShard* es); - static OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_first); + static OpResult OpUnion(const OpArgs& op_args, const ArgSlice& args); + static OpResult OpDiff(const Transaction* t, EngineShard* es); + static OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_first); // count - how many elements to pop. - static OpResult OpPop(const OpArgs& op_args, std::string_view key, unsigned count); + static OpResult OpPop(const OpArgs& op_args, std::string_view key, unsigned count); };