Add MGet and MSet commands

This commit is contained in:
Roman Gershman 2021-12-23 15:11:46 +02:00
parent ebd404ff5d
commit b1f32e5ebf
6 changed files with 129 additions and 6 deletions

View File

@ -98,6 +98,10 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
return cntx->SendError(WrongNumArgsError(cmd_str)); return cntx->SendError(WrongNumArgsError(cmd_str));
} }
if (cid->key_arg_step() == 2 && (args.size() % 2) == 0) {
return cntx->SendError(WrongNumArgsError(cmd_str));
}
uint64_t start_usec = ProactorBase::GetMonotonicTimeNs(), end_usec; uint64_t start_usec = ProactorBase::GetMonotonicTimeNs(), end_usec;
// Create command transaction // Create command transaction
@ -171,7 +175,7 @@ void Service::RegisterHttp(HttpListenerBase* listener) {
void Service::Ping(CmdArgList args, ConnectionContext* cntx) { void Service::Ping(CmdArgList args, ConnectionContext* cntx) {
if (args.size() > 2) { if (args.size() > 2) {
return cntx->SendError("wrong number of arguments for 'ping' command"); return cntx->SendError(WrongNumArgsError("PING"));
} }
ping_qps.Inc(); ping_qps.Inc();

View File

@ -177,6 +177,20 @@ void ReplyBuilder::SendGetNotFound() {
} }
} }
void ReplyBuilder::SendMGetResponse(const StrOrNil* arr, uint32_t count) {
string res = absl::StrCat("*", count, kCRLF);
for (size_t i = 0; i < count; ++i) {
if (arr[i]) {
StrAppend(&res, "$", arr[i]->size(), kCRLF);
res.append(*arr[i]).append(kCRLF);
} else {
res.append("$-1\r\n");
}
}
as_resp()->SendDirect(res);
}
void ReplyBuilder::SendSimpleStrArr(const std::string_view* arr, uint32_t count) { void ReplyBuilder::SendSimpleStrArr(const std::string_view* arr, uint32_t count) {
CHECK(protocol_ == Protocol::REDIS); CHECK(protocol_ == Protocol::REDIS);
string res = absl::StrCat("*", count, kCRLF); string res = absl::StrCat("*", count, kCRLF);

View File

@ -1,6 +1,7 @@
// Copyright 2021, Roman Gershman. All rights reserved. // Copyright 2021, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms. // See LICENSE for licensing terms.
// //
#include <optional>
#include <string_view> #include <string_view>
#include "core/op_status.h" #include "core/op_status.h"
@ -12,6 +13,8 @@ namespace dfly {
class BaseSerializer { class BaseSerializer {
public: public:
explicit BaseSerializer(::io::Sink* sink); explicit BaseSerializer(::io::Sink* sink);
virtual ~BaseSerializer() {
}
std::error_code ec() const { std::error_code ec() const {
return ec_; return ec_;
@ -107,6 +110,9 @@ class ReplyBuilder {
as_resp()->SendNull(); as_resp()->SendNull();
} }
using StrOrNil = std::optional<std::string_view>;
void SendMGetResponse(const StrOrNil* arr, uint32_t count);
private: private:
RespSerializer* as_resp() { RespSerializer* as_resp() {
return static_cast<RespSerializer*>(serializer_.get()); return static_cast<RespSerializer*>(serializer_.get());

View File

@ -189,6 +189,96 @@ void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) {
return cntx->SendNull(); return cntx->SendNull();
} }
void StringFamily::MGet(CmdArgList args, ConnectionContext* cntx) {
DCHECK_GT(args.size(), 1U);
Transaction* transaction = cntx->transaction;
unsigned shard_count = transaction->shard_set()->size();
std::vector<MGetResponse> mget_resp(shard_count);
auto cb = [&](Transaction* t, EngineShard* shard) {
ShardId sid = shard->shard_id();
mget_resp[sid] = OpMGet(t, shard);
return OpStatus::OK;
};
// MGet requires locking as well. For example, if coordinator A applied W(x) and then W(y)
// it necessarily means that whoever observed y, must observe x.
// Without locking, mget x y could read stale x but latest y.
OpStatus result = transaction->ScheduleSingleHop(std::move(cb));
CHECK_EQ(OpStatus::OK, result);
// reorder the responses back according to the order of their corresponding keys.
vector<std::optional<std::string_view>> res(args.size() - 1);
for (ShardId sid = 0; sid < shard_count; ++sid) {
if (!transaction->IsActive(sid))
continue;
auto& values = mget_resp[sid];
ArgSlice slice = transaction->ShardArgsInShard(sid);
DCHECK(!slice.empty());
DCHECK_EQ(slice.size(), values.size());
for (size_t j = 0; j < slice.size(); ++j) {
uint32_t indx = transaction->ReverseArgIndex(sid, j);
res[indx] = values[j];
}
}
return cntx->SendMGetResponse(res.data(), res.size());
}
void StringFamily::MSet(CmdArgList args, ConnectionContext* cntx) {
Transaction* transaction = cntx->transaction;
if (VLOG_IS_ON(2)) {
string str;
for (size_t i = 1; i < args.size(); ++i) {
absl::StrAppend(&str, " ", ArgS(args, i));
}
LOG(INFO) << "MSET/" << transaction->unique_shard_cnt() << str;
}
OpStatus status = transaction->ScheduleSingleHop(&OpMSet);
CHECK_EQ(OpStatus::OK, status);
DVLOG(2) << "MSet run " << transaction->DebugId();
return cntx->SendOk();
}
auto StringFamily::OpMGet(const Transaction* t, EngineShard* shard) -> MGetResponse {
auto args = t->ShardArgsInShard(shard->shard_id());
DCHECK(!args.empty());
MGetResponse response(args.size());
auto& db_slice = shard->db_slice();
for (size_t i = 0; i < args.size(); ++i) {
OpResult<MainIterator> de_res = db_slice.Find(0, args[i]);
if (de_res.ok()) {
response[i] = de_res.value()->second.str;
}
}
return response;
}
OpStatus StringFamily::OpMSet(const Transaction* t, EngineShard* es) {
ArgSlice largs = t->ShardArgsInShard(es->shard_id());
CHECK(!largs.empty() && largs.size() % 2 == 0);
SetCmd::SetParams params{0};
SetCmd sg(&es->db_slice());
for (size_t i = 0; i < largs.size(); i += 2) {
DVLOG(1) << "MSet " << largs[i] << ":" << largs[i + 1];
auto res = sg.Set(params, largs[i], largs[i + 1]);
CHECK(res.ok()) << res << " " << largs[i]; // TODO - handle OOM etc.
}
return OpStatus::OK;
}
void StringFamily::Init(util::ProactorPool* pp) { void StringFamily::Init(util::ProactorPool* pp) {
set_qps.Init(pp); set_qps.Init(pp);
get_qps.Init(pp); get_qps.Init(pp);
@ -204,7 +294,9 @@ void StringFamily::Shutdown() {
void StringFamily::Register(CommandRegistry* registry) { void StringFamily::Register(CommandRegistry* registry) {
*registry << CI{"SET", CO::WRITE | CO::DENYOOM, -3, 1, 1, 1}.HFUNC(Set) *registry << CI{"SET", CO::WRITE | CO::DENYOOM, -3, 1, 1, 1}.HFUNC(Set)
<< CI{"GET", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(Get) << CI{"GET", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(Get)
<< CI{"GETSET", CO::WRITE | CO::DENYOOM | CO::FAST, 3, 1, 1, 1}.HFUNC(GetSet); << CI{"GETSET", CO::WRITE | CO::DENYOOM | CO::FAST, 3, 1, 1, 1}.HFUNC(GetSet)
<< CI{"MGET", CO::READONLY | CO::FAST, -2, 1, -1, 1}.HFUNC(MGet)
<< CI{"MSET", CO::WRITE | CO::DENYOOM, -3, 1, -1, 2}.HFUNC(MSet);
} }
} // namespace dfly } // namespace dfly

View File

@ -4,8 +4,8 @@
#pragma once #pragma once
#include "server/engine_shard_set.h"
#include "server/common_types.h" #include "server/common_types.h"
#include "server/engine_shard_set.h"
#include "util/proactor_pool.h" #include "util/proactor_pool.h"
namespace dfly { namespace dfly {
@ -13,7 +13,6 @@ namespace dfly {
class ConnectionContext; class ConnectionContext;
class CommandRegistry; class CommandRegistry;
class SetCmd { class SetCmd {
DbSlice* db_slice_; DbSlice* db_slice_;
@ -30,7 +29,7 @@ class SetCmd {
// Relative value based on now. 0 means no expiration. // Relative value based on now. 0 means no expiration.
uint64_t expire_after_ms = 0; uint64_t expire_after_ms = 0;
mutable std::optional<std::string>* prev_val = nullptr; // GETSET option mutable std::optional<std::string>* prev_val = nullptr; // GETSET option
bool keep_expire = false; // KEEPTTL - TODO: to implement it. bool keep_expire = false; // KEEPTTL - TODO: to implement it.
explicit SetParams(DbIndex dib) : db_index(dib) { explicit SetParams(DbIndex dib) : db_index(dib) {
} }
@ -54,6 +53,14 @@ class StringFamily {
static void Set(CmdArgList args, ConnectionContext* cntx); static void Set(CmdArgList args, ConnectionContext* cntx);
static void Get(CmdArgList args, ConnectionContext* cntx); static void Get(CmdArgList args, ConnectionContext* cntx);
static void GetSet(CmdArgList args, ConnectionContext* cntx); static void GetSet(CmdArgList args, ConnectionContext* cntx);
static void MGet(CmdArgList args, ConnectionContext* cntx);
static void MSet(CmdArgList args, ConnectionContext* cntx);
using MGetResponse = std::vector<std::optional<std::string>>;
static MGetResponse OpMGet(const Transaction* t, EngineShard* shard);
static OpStatus OpMSet(const Transaction* t, EngineShard* es);
}; };
} // namespace dfly } // namespace dfly

View File

@ -89,7 +89,7 @@ void Transaction::InitByArgs(CmdArgList args) {
} }
CHECK(cid_->key_arg_step() == 1 || cid_->key_arg_step() == 2); CHECK(cid_->key_arg_step() == 1 || cid_->key_arg_step() == 2);
CHECK(cid_->key_arg_step() == 1 || (args.size() % 2) == 1); DCHECK(cid_->key_arg_step() == 1 || (args.size() % 2) == 1);
// Reuse thread-local temporary storage. Since this code is non-preemptive we can use it here. // Reuse thread-local temporary storage. Since this code is non-preemptive we can use it here.
auto& shard_index = tmp_space.shard_cache; auto& shard_index = tmp_space.shard_cache;