Allow custom argument validators. Expand transaction argument parsing to commands like EVAL

This commit is contained in:
Roman Gershman 2022-02-06 17:38:53 +02:00
parent cc53bde091
commit b1829c3fe0
8 changed files with 173 additions and 62 deletions

View File

@ -38,6 +38,7 @@ const char* OptName(CommandOpt fl);
class CommandId {
public:
using Handler = std::function<void(CmdArgList, ConnectionContext*)>;
using ArgValidator = std::function<bool(CmdArgList, ConnectionContext*)>;
/**
* @brief Construct a new Command Id object
@ -89,10 +90,21 @@ class CommandId {
return *this;
}
CommandId& SetValidator(ArgValidator f) {
validator_ = std::move(f);
return *this;
}
void Invoke(CmdArgList args, ConnectionContext* cntx) const {
handler_(std::move(args), cntx);
}
// Returns true if validation succeeded.
bool Validate(CmdArgList args, ConnectionContext* cntx) const {
return !validator_ || validator_(std::move(args), cntx);
}
static const char* OptName(CO::CommandOpt fl);
static uint32_t OptCount(uint32_t mask);
@ -106,6 +118,7 @@ class CommandId {
int8_t step_key_;
Handler handler_;
ArgValidator validator_;
};
class CommandRegistry {

View File

@ -221,8 +221,8 @@ TEST_F(DflyEngineTest, Eval) {
resp = Run({"incrby", "foo", "42"});
EXPECT_THAT(resp[0], IntArg(42));
resp = Run({"eval", "return redis.call('get', 'foo')", "0"});
EXPECT_THAT(resp[0], StrArg("42"));
// resp = Run({"eval", "return redis.call('get', 'foo')", "0"});
// EXPECT_THAT(resp[0], StrArg("42"));
}
TEST_F(DflyEngineTest, EvalSha) {

View File

@ -546,7 +546,8 @@ using CI = CommandId;
#define HFUNC(x) SetHandler(&GenericFamily::x)
void GenericFamily::Register(CommandRegistry* registry) {
constexpr auto kSelectOpts = CO::LOADING | CO::FAST;
constexpr auto kSelectOpts = CO::LOADING | CO::FAST | CO::NOSCRIPT;
*registry << CI{"DEL", CO::WRITE, -2, 1, -1, 1}.HFUNC(Del)
/* Redis compaitibility:
* We don't allow PING during loading since in Redis PING is used as

View File

@ -100,7 +100,8 @@ TEST_F(ListFamilyTest, BLPopBlocking) {
LOG(INFO) << "pop1";
});
this_fiber::sleep_for(30us);
pp_->at(1)->Await([&] { Run({"lpush", "x", "2", "1"}); });
pp_->at(1)->Await([&] { Run("B1", {"lpush", "x", "2", "1"}); });
fb0.join();
fb1.join();

View File

@ -251,6 +251,35 @@ bool IsSHA(string_view str) {
return true;
}
bool IsTransactional(const CommandId* cid) {
if (cid->first_key_pos() > 0 || (cid->opt_mask() & CO::GLOBAL_TRANS))
return true;
string_view name{cid->name()};
if (name == "EVAL" || name == "EVALSHA")
return true;
return false;
}
bool EvalValidator(CmdArgList args, ConnectionContext* cntx) {
string_view num_keys_str = ArgS(args, 2);
int32_t num_keys;
if (!absl::SimpleAtoi(num_keys_str, &num_keys) || num_keys < 0) {
(*cntx)->SendError(kInvalidIntErr);
return false;
}
if (unsigned(num_keys) > args.size() - 3) {
(*cntx)->SendError("Number of keys can't be greater than number of args");
return false;
}
return true;
}
} // namespace
Service::Service(ProactorPool* pp) : pp_(*pp), shard_set_(pp), server_family_(this) {
@ -365,11 +394,23 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendError(WrongNumArgsError(cmd_str));
}
if (under_multi && (cid->opt_mask() & CO::ADMIN)) {
(*cntx)->SendError("Can not run admin commands under multi-transactions");
// Validate more complicated cases with custom validators.
if (!cid->Validate(args, cntx)) {
return;
}
if (under_multi) {
if (cid->opt_mask() & CO::ADMIN) {
(*cntx)->SendError("Can not run admin commands under transactions");
return;
}
if (string_view{cid->name()} == "SELECT") {
(*cntx)->SendError("Can not call SELECT within a transaction");
return;
}
}
std::move(multi_error).Cancel();
if (cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE && !is_trans_cmd) {
@ -389,17 +430,19 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
// Create command transaction
intrusive_ptr<Transaction> dist_trans;
if (cid->first_key_pos() > 0 || (cid->opt_mask() & CO::GLOBAL_TRANS)) {
if (!under_script) {
DCHECK(cntx->transaction == nullptr);
if (IsTransactional(cid)) {
dist_trans.reset(new Transaction{cid, &shard_set_});
cntx->transaction = dist_trans.get();
if (cid->first_key_pos() > 0) {
dist_trans->InitByArgs(cntx->conn_state.db_index, args);
cntx->last_command_debug.shards_count = cntx->transaction->unique_shard_cnt();
}
} else {
cntx->transaction = nullptr;
}
}
cntx->cid = cid;
cmd_req.Inc({cid->name()});
@ -411,8 +454,11 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
cntx->last_command_debug.clock = dist_trans->txid();
cntx->last_command_debug.is_ooo = dist_trans->IsOOO();
}
if (!under_script) {
cntx->transaction = nullptr;
}
}
void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view value,
ConnectionContext* cntx) {
@ -508,16 +554,10 @@ void Service::CallFromScript(CmdArgList args, ObjectExplorer* reply, ConnectionC
}
void Service::Eval(CmdArgList args, ConnectionContext* cntx) {
string_view num_keys_str = ArgS(args, 2);
int32_t num_keys;
uint32_t num_keys;
if (!absl::SimpleAtoi(num_keys_str, &num_keys) || num_keys < 0) {
return (*cntx)->SendError(kInvalidIntErr);
}
CHECK(absl::SimpleAtoi(ArgS(args, 2), &num_keys)); // we already validated this
if (unsigned(num_keys) > args.size() - 3) {
return (*cntx)->SendError("Number of keys can't be greater than number of args");
}
string_view body = ArgS(args, 1);
body = absl::StripAsciiWhitespace(body);
@ -547,15 +587,9 @@ void Service::Eval(CmdArgList args, ConnectionContext* cntx) {
void Service::EvalSha(CmdArgList args, ConnectionContext* cntx) {
string_view num_keys_str = ArgS(args, 2);
int32_t num_keys;
uint32_t num_keys;
if (!absl::SimpleAtoi(num_keys_str, &num_keys) || num_keys < 0) {
return (*cntx)->SendError(kInvalidIntErr);
}
if (unsigned(num_keys) > args.size() - 3) {
return (*cntx)->SendError("Number of keys can't be greater than number of args");
}
CHECK(absl::SimpleAtoi(num_keys_str, &num_keys));
ToLower(&args[1]);
@ -708,8 +742,8 @@ void Service::RegisterCommands() {
registry_ << CI{"QUIT", CO::READONLY | CO::FAST, 1, 0, 0, 0}.HFUNC(Quit)
<< CI{"MULTI", CO::NOSCRIPT | CO::FAST | CO::LOADING, 1, 0, 0, 0}.HFUNC(Multi)
<< CI{"EVAL", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(Eval)
<< CI{"EVALSHA", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(EvalSha)
<< CI{"EVAL", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(Eval).SetValidator(&EvalValidator)
<< CI{"EVALSHA", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(EvalSha).SetValidator(&EvalValidator)
<< CI{"EXEC", kExecMask, 1, 0, 0, 0}.MFUNC(Exec);
StringFamily::Register(&registry_);

View File

@ -161,8 +161,12 @@ RespVec BaseFamilyTest::Run(initializer_list<std::string_view> list) {
return pp_->at(0)->Await([&] { return this->Run(list); });
}
mu_.lock();
string id = GetId();
return Run(id, list);
}
RespVec BaseFamilyTest::Run(std::string_view id, std::initializer_list<std::string_view> list) {
mu_.lock();
auto [it, inserted] = connections_.emplace(id, nullptr);
if (inserted) {
@ -178,8 +182,12 @@ RespVec BaseFamilyTest::Run(initializer_list<std::string_view> list) {
auto& context = conn->cmd_cntx;
context.shard_set = ess_;
DCHECK(context.transaction == nullptr);
service_->DispatchCommand(cmd_arg_list, &context);
DCHECK(context.transaction == nullptr);
unique_lock lk(mu_);
last_cmd_dbg_info_ = context.last_command_debug;

View File

@ -96,7 +96,10 @@ class BaseFamilyTest : public ::testing::Test {
void TearDown() override;
protected:
RespVec Run(std::initializer_list<std::string_view> list);
RespVec Run(std::string_view id, std::initializer_list<std::string_view> list);
int64_t CheckedInt(std::initializer_list<std::string_view> list);
bool IsLocked(DbIndex db_index, std::string_view key) const;

View File

@ -4,6 +4,8 @@
#include "server/transaction.h"
#include <absl/strings/match.h>
#include "base/logging.h"
#include "server/command_registry.h"
#include "server/db_slice.h"
@ -22,6 +24,43 @@ std::atomic_uint64_t op_seq{1};
[[maybe_unused]] constexpr size_t kTransSize = sizeof(Transaction);
struct KeyIndex {
unsigned start;
unsigned end; // not including
unsigned step;
};
KeyIndex DetermineKeys(const CommandId* cid, const CmdArgList& args) {
DCHECK_EQ(0u, cid->opt_mask() & CO::GLOBAL_TRANS);
KeyIndex key_index;
if (cid->first_key_pos() > 0) {
key_index.start = cid->first_key_pos();
int last = cid->last_key_pos();
key_index.end = last > 0 ? last + 1 : (int(args.size()) + 1 + last);
key_index.step = cid->key_arg_step();
return key_index;
}
string_view name{cid->name()};
if (name == "EVAL" || name == "EVALSHA") {
DCHECK_GE(args.size(), 3u);
uint32_t num_keys;
CHECK(absl::SimpleAtoi(ArgS(args, 2), &num_keys));
key_index.start = 3;
key_index.end = 3 + num_keys;
key_index.step = 1;
return key_index;
}
LOG(FATAL) << "Not supported";
return key_index;
}
} // namespace
struct Transaction::FindFirstProcessor {
@ -106,18 +145,6 @@ Transaction::Transaction(const CommandId* cid, EngineShardSet* ess) : cid_(cid),
multi_.reset(new Multi);
}
trans_options_ = cid_->opt_mask();
bool single_key = cid_->first_key_pos() > 0 && !cid_->is_multi_key();
if (!multi_ && single_key) {
shard_data_.resize(1); // Single key optimization
} else {
// Our shard_data is not sparse, so we must allocate for all threads :(
shard_data_.resize(ess_->size());
}
if (IsGlobal()) {
unique_shard_cnt_ = ess->size();
}
}
Transaction::~Transaction() {
@ -130,8 +157,8 @@ Transaction::~Transaction() {
* a. T spans a single shard and its not multi.
* unique_shard_id_ is predefined before the schedule() is called.
* In that case only a single thread will be scheduled and it will use shard_data[0] just becase
* shard_data.size() = 1. Engine thread can access any data because there is schedule barrier
* between InitByArgs and RunInShard/IsArmedInShard functions.
* shard_data.size() = 1. Coordinator thread can access any data because there is a
* schedule barrier between InitByArgs and RunInShard/IsArmedInShard functions.
* b. T spans multiple shards and its not multi
* In that case multiple threads will be scheduled. Similarly they have a schedule barrier,
* and IsArmedInShard can read any variable from shard_data[x].
@ -145,24 +172,46 @@ Transaction::~Transaction() {
**/
void Transaction::InitByArgs(DbIndex index, CmdArgList args) {
CHECK_GT(args.size(), 1U);
CHECK_LT(size_t(cid_->first_key_pos()), args.size());
DCHECK_EQ(unique_shard_cnt_, 0u);
DCHECK(!IsGlobal()) << "Global transactions do not have keys";
db_index_ = index;
if (!multi_ && !cid_->is_multi_key()) { // Single key optimization.
auto key = ArgS(args, cid_->first_key_pos());
if (IsGlobal()) {
unique_shard_cnt_ = ess_->size();
shard_data_.resize(unique_shard_cnt_);
return;
}
CHECK_GT(args.size(), 1U); // first entry is the command name.
DCHECK_EQ(unique_shard_cnt_, 0u);
KeyIndex key_index = DetermineKeys(cid_, args);
if (key_index.start == args.size()) {
CHECK(absl::StartsWith(cid_->name(), "EVAL"));
return;
}
DCHECK_LT(key_index.start, args.size());
bool single_key = key_index.start > 0 && !cid_->is_multi_key();
if (!multi_ && single_key) {
shard_data_.resize(1); // Single key optimization
} else {
// Our shard_data is not sparse, so we must allocate for all threads :(
shard_data_.resize(ess_->size());
}
if (!multi_ && single_key) { // Single key optimization.
auto key = ArgS(args, key_index.start);
args_.push_back(key);
unique_shard_cnt_ = 1;
unique_shard_id_ = Shard(key, ess_->size());
return;
}
CHECK(cid_->key_arg_step() == 1 || cid_->key_arg_step() == 2);
DCHECK(cid_->key_arg_step() == 1 || (args.size() % 2) == 1);
CHECK(key_index.step == 1 || key_index.step == 2);
DCHECK(key_index.step == 1 || (args.size() % 2) == 1);
// Reuse thread-local temporary storage. Since this code is non-preemptive we can use it here.
auto& shard_index = tmp_space.shard_cache;
@ -171,6 +220,8 @@ void Transaction::InitByArgs(DbIndex index, CmdArgList args) {
v.Clear();
}
// TODO: to determine correctly locking mode for transactions, scripts
// and regular commands.
IntentLock::Mode mode = IntentLock::EXCLUSIVE;
if (multi_) {
mode = Mode();
@ -178,10 +229,8 @@ void Transaction::InitByArgs(DbIndex index, CmdArgList args) {
DCHECK_LT(int(mode), 2);
}
size_t key_end = cid_->last_key_pos() > 0 ? cid_->last_key_pos() + 1
: (args.size() + 1 + cid_->last_key_pos());
for (size_t i = 1; i < key_end; ++i) {
std::string_view key = ArgS(args, i);
for (unsigned i = key_index.start; i < key_index.end; ++i) {
string_view key = ArgS(args, i);
uint32_t sid = Shard(key, shard_data_.size());
shard_index[sid].args.push_back(key);
shard_index[sid].original_index.push_back(i - 1);
@ -190,7 +239,7 @@ void Transaction::InitByArgs(DbIndex index, CmdArgList args) {
multi_->locks[key].cnt[int(mode)]++;
};
if (cid_->key_arg_step() == 2) { // value
if (key_index.step == 2) { // value
++i;
auto val = ArgS(args, i);
shard_index[sid].args.push_back(val);
@ -198,7 +247,7 @@ void Transaction::InitByArgs(DbIndex index, CmdArgList args) {
}
}
args_.resize(key_end - 1);
args_.resize(key_index.end - key_index.start);
reverse_index_.resize(args_.size());
auto next_arg = args_.begin();
@ -209,7 +258,9 @@ void Transaction::InitByArgs(DbIndex index, CmdArgList args) {
for (size_t i = 0; i < shard_data_.size(); ++i) {
auto& sd = shard_data_[i];
auto& si = shard_index[i];
CHECK_LT(si.args.size(), 1u << 15);
sd.arg_count = si.args.size();
sd.arg_start = next_arg - args_.begin();
sd.local_mask = 0;