diff --git a/server/main_service.cc b/server/main_service.cc index 7043387..a5a4e76 100644 --- a/server/main_service.cc +++ b/server/main_service.cc @@ -98,6 +98,10 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) { return cntx->SendError(WrongNumArgsError(cmd_str)); } + if (cid->key_arg_step() == 2 && (args.size() % 2) == 0) { + return cntx->SendError(WrongNumArgsError(cmd_str)); + } + uint64_t start_usec = ProactorBase::GetMonotonicTimeNs(), end_usec; // Create command transaction @@ -171,7 +175,7 @@ void Service::RegisterHttp(HttpListenerBase* listener) { void Service::Ping(CmdArgList args, ConnectionContext* cntx) { if (args.size() > 2) { - return cntx->SendError("wrong number of arguments for 'ping' command"); + return cntx->SendError(WrongNumArgsError("PING")); } ping_qps.Inc(); diff --git a/server/reply_builder.cc b/server/reply_builder.cc index 3feaa1e..a5a840d 100644 --- a/server/reply_builder.cc +++ b/server/reply_builder.cc @@ -177,6 +177,20 @@ void ReplyBuilder::SendGetNotFound() { } } +void ReplyBuilder::SendMGetResponse(const StrOrNil* arr, uint32_t count) { + string res = absl::StrCat("*", count, kCRLF); + for (size_t i = 0; i < count; ++i) { + if (arr[i]) { + StrAppend(&res, "$", arr[i]->size(), kCRLF); + res.append(*arr[i]).append(kCRLF); + } else { + res.append("$-1\r\n"); + } + } + + as_resp()->SendDirect(res); +} + void ReplyBuilder::SendSimpleStrArr(const std::string_view* arr, uint32_t count) { CHECK(protocol_ == Protocol::REDIS); string res = absl::StrCat("*", count, kCRLF); diff --git a/server/reply_builder.h b/server/reply_builder.h index 3dc3e9f..7ca38f1 100644 --- a/server/reply_builder.h +++ b/server/reply_builder.h @@ -1,6 +1,7 @@ // Copyright 2021, Roman Gershman. All rights reserved. // See LICENSE for licensing terms. // +#include #include #include "core/op_status.h" @@ -12,6 +13,8 @@ namespace dfly { class BaseSerializer { public: explicit BaseSerializer(::io::Sink* sink); + virtual ~BaseSerializer() { + } std::error_code ec() const { return ec_; @@ -107,6 +110,9 @@ class ReplyBuilder { as_resp()->SendNull(); } + using StrOrNil = std::optional; + void SendMGetResponse(const StrOrNil* arr, uint32_t count); + private: RespSerializer* as_resp() { return static_cast(serializer_.get()); diff --git a/server/string_family.cc b/server/string_family.cc index ffe84f3..3692b90 100644 --- a/server/string_family.cc +++ b/server/string_family.cc @@ -189,6 +189,96 @@ void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) { return cntx->SendNull(); } +void StringFamily::MGet(CmdArgList args, ConnectionContext* cntx) { + DCHECK_GT(args.size(), 1U); + + Transaction* transaction = cntx->transaction; + unsigned shard_count = transaction->shard_set()->size(); + std::vector mget_resp(shard_count); + + auto cb = [&](Transaction* t, EngineShard* shard) { + ShardId sid = shard->shard_id(); + mget_resp[sid] = OpMGet(t, shard); + return OpStatus::OK; + }; + + // MGet requires locking as well. For example, if coordinator A applied W(x) and then W(y) + // it necessarily means that whoever observed y, must observe x. + // Without locking, mget x y could read stale x but latest y. + OpStatus result = transaction->ScheduleSingleHop(std::move(cb)); + CHECK_EQ(OpStatus::OK, result); + + // reorder the responses back according to the order of their corresponding keys. + vector> res(args.size() - 1); + for (ShardId sid = 0; sid < shard_count; ++sid) { + if (!transaction->IsActive(sid)) + continue; + auto& values = mget_resp[sid]; + ArgSlice slice = transaction->ShardArgsInShard(sid); + DCHECK(!slice.empty()); + DCHECK_EQ(slice.size(), values.size()); + for (size_t j = 0; j < slice.size(); ++j) { + uint32_t indx = transaction->ReverseArgIndex(sid, j); + res[indx] = values[j]; + } + } + + return cntx->SendMGetResponse(res.data(), res.size()); +} + +void StringFamily::MSet(CmdArgList args, ConnectionContext* cntx) { + Transaction* transaction = cntx->transaction; + + if (VLOG_IS_ON(2)) { + string str; + for (size_t i = 1; i < args.size(); ++i) { + absl::StrAppend(&str, " ", ArgS(args, i)); + } + LOG(INFO) << "MSET/" << transaction->unique_shard_cnt() << str; + } + + OpStatus status = transaction->ScheduleSingleHop(&OpMSet); + CHECK_EQ(OpStatus::OK, status); + + DVLOG(2) << "MSet run " << transaction->DebugId(); + + return cntx->SendOk(); +} + + +auto StringFamily::OpMGet(const Transaction* t, EngineShard* shard) -> MGetResponse { + auto args = t->ShardArgsInShard(shard->shard_id()); + DCHECK(!args.empty()); + + MGetResponse response(args.size()); + + auto& db_slice = shard->db_slice(); + for (size_t i = 0; i < args.size(); ++i) { + OpResult de_res = db_slice.Find(0, args[i]); + if (de_res.ok()) { + response[i] = de_res.value()->second.str; + } + } + + return response; +} + +OpStatus StringFamily::OpMSet(const Transaction* t, EngineShard* es) { + ArgSlice largs = t->ShardArgsInShard(es->shard_id()); + CHECK(!largs.empty() && largs.size() % 2 == 0); + + SetCmd::SetParams params{0}; + SetCmd sg(&es->db_slice()); + for (size_t i = 0; i < largs.size(); i += 2) { + DVLOG(1) << "MSet " << largs[i] << ":" << largs[i + 1]; + auto res = sg.Set(params, largs[i], largs[i + 1]); + CHECK(res.ok()) << res << " " << largs[i]; // TODO - handle OOM etc. + } + + return OpStatus::OK; +} + + void StringFamily::Init(util::ProactorPool* pp) { set_qps.Init(pp); get_qps.Init(pp); @@ -204,7 +294,9 @@ void StringFamily::Shutdown() { void StringFamily::Register(CommandRegistry* registry) { *registry << CI{"SET", CO::WRITE | CO::DENYOOM, -3, 1, 1, 1}.HFUNC(Set) << CI{"GET", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(Get) - << CI{"GETSET", CO::WRITE | CO::DENYOOM | CO::FAST, 3, 1, 1, 1}.HFUNC(GetSet); + << CI{"GETSET", CO::WRITE | CO::DENYOOM | CO::FAST, 3, 1, 1, 1}.HFUNC(GetSet) + << CI{"MGET", CO::READONLY | CO::FAST, -2, 1, -1, 1}.HFUNC(MGet) + << CI{"MSET", CO::WRITE | CO::DENYOOM, -3, 1, -1, 2}.HFUNC(MSet); } } // namespace dfly diff --git a/server/string_family.h b/server/string_family.h index b1dd4f8..eaba08c 100644 --- a/server/string_family.h +++ b/server/string_family.h @@ -4,8 +4,8 @@ #pragma once -#include "server/engine_shard_set.h" #include "server/common_types.h" +#include "server/engine_shard_set.h" #include "util/proactor_pool.h" namespace dfly { @@ -13,7 +13,6 @@ namespace dfly { class ConnectionContext; class CommandRegistry; - class SetCmd { DbSlice* db_slice_; @@ -30,7 +29,7 @@ class SetCmd { // Relative value based on now. 0 means no expiration. uint64_t expire_after_ms = 0; mutable std::optional* prev_val = nullptr; // GETSET option - bool keep_expire = false; // KEEPTTL - TODO: to implement it. + bool keep_expire = false; // KEEPTTL - TODO: to implement it. explicit SetParams(DbIndex dib) : db_index(dib) { } @@ -54,6 +53,14 @@ class StringFamily { static void Set(CmdArgList args, ConnectionContext* cntx); static void Get(CmdArgList args, ConnectionContext* cntx); static void GetSet(CmdArgList args, ConnectionContext* cntx); + static void MGet(CmdArgList args, ConnectionContext* cntx); + static void MSet(CmdArgList args, ConnectionContext* cntx); + + using MGetResponse = std::vector>; + + static MGetResponse OpMGet(const Transaction* t, EngineShard* shard); + + static OpStatus OpMSet(const Transaction* t, EngineShard* es); }; } // namespace dfly diff --git a/server/transaction.cc b/server/transaction.cc index 2ce0f9e..c48e6e7 100644 --- a/server/transaction.cc +++ b/server/transaction.cc @@ -89,7 +89,7 @@ void Transaction::InitByArgs(CmdArgList args) { } CHECK(cid_->key_arg_step() == 1 || cid_->key_arg_step() == 2); - CHECK(cid_->key_arg_step() == 1 || (args.size() % 2) == 1); + DCHECK(cid_->key_arg_step() == 1 || (args.size() % 2) == 1); // Reuse thread-local temporary storage. Since this code is non-preemptive we can use it here. auto& shard_index = tmp_space.shard_cache;