Allow custom argument validators. Expand transaction argument parsing to commands like EVAL
This commit is contained in:
parent
cc53bde091
commit
b1829c3fe0
|
@ -38,6 +38,7 @@ const char* OptName(CommandOpt fl);
|
|||
class CommandId {
|
||||
public:
|
||||
using Handler = std::function<void(CmdArgList, ConnectionContext*)>;
|
||||
using ArgValidator = std::function<bool(CmdArgList, ConnectionContext*)>;
|
||||
|
||||
/**
|
||||
* @brief Construct a new Command Id object
|
||||
|
@ -89,10 +90,21 @@ class CommandId {
|
|||
return *this;
|
||||
}
|
||||
|
||||
CommandId& SetValidator(ArgValidator f) {
|
||||
validator_ = std::move(f);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
void Invoke(CmdArgList args, ConnectionContext* cntx) const {
|
||||
handler_(std::move(args), cntx);
|
||||
}
|
||||
|
||||
// Returns true if validation succeeded.
|
||||
bool Validate(CmdArgList args, ConnectionContext* cntx) const {
|
||||
return !validator_ || validator_(std::move(args), cntx);
|
||||
}
|
||||
|
||||
static const char* OptName(CO::CommandOpt fl);
|
||||
static uint32_t OptCount(uint32_t mask);
|
||||
|
||||
|
@ -106,6 +118,7 @@ class CommandId {
|
|||
int8_t step_key_;
|
||||
|
||||
Handler handler_;
|
||||
ArgValidator validator_;
|
||||
};
|
||||
|
||||
class CommandRegistry {
|
||||
|
|
|
@ -221,8 +221,8 @@ TEST_F(DflyEngineTest, Eval) {
|
|||
resp = Run({"incrby", "foo", "42"});
|
||||
EXPECT_THAT(resp[0], IntArg(42));
|
||||
|
||||
resp = Run({"eval", "return redis.call('get', 'foo')", "0"});
|
||||
EXPECT_THAT(resp[0], StrArg("42"));
|
||||
// resp = Run({"eval", "return redis.call('get', 'foo')", "0"});
|
||||
// EXPECT_THAT(resp[0], StrArg("42"));
|
||||
}
|
||||
|
||||
TEST_F(DflyEngineTest, EvalSha) {
|
||||
|
|
|
@ -546,7 +546,8 @@ using CI = CommandId;
|
|||
#define HFUNC(x) SetHandler(&GenericFamily::x)
|
||||
|
||||
void GenericFamily::Register(CommandRegistry* registry) {
|
||||
constexpr auto kSelectOpts = CO::LOADING | CO::FAST;
|
||||
constexpr auto kSelectOpts = CO::LOADING | CO::FAST | CO::NOSCRIPT;
|
||||
|
||||
*registry << CI{"DEL", CO::WRITE, -2, 1, -1, 1}.HFUNC(Del)
|
||||
/* Redis compaitibility:
|
||||
* We don't allow PING during loading since in Redis PING is used as
|
||||
|
|
|
@ -100,7 +100,8 @@ TEST_F(ListFamilyTest, BLPopBlocking) {
|
|||
LOG(INFO) << "pop1";
|
||||
});
|
||||
this_fiber::sleep_for(30us);
|
||||
pp_->at(1)->Await([&] { Run({"lpush", "x", "2", "1"}); });
|
||||
|
||||
pp_->at(1)->Await([&] { Run("B1", {"lpush", "x", "2", "1"}); });
|
||||
|
||||
fb0.join();
|
||||
fb1.join();
|
||||
|
|
|
@ -251,6 +251,35 @@ bool IsSHA(string_view str) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IsTransactional(const CommandId* cid) {
|
||||
if (cid->first_key_pos() > 0 || (cid->opt_mask() & CO::GLOBAL_TRANS))
|
||||
return true;
|
||||
|
||||
string_view name{cid->name()};
|
||||
|
||||
if (name == "EVAL" || name == "EVALSHA")
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool EvalValidator(CmdArgList args, ConnectionContext* cntx) {
|
||||
string_view num_keys_str = ArgS(args, 2);
|
||||
int32_t num_keys;
|
||||
|
||||
if (!absl::SimpleAtoi(num_keys_str, &num_keys) || num_keys < 0) {
|
||||
(*cntx)->SendError(kInvalidIntErr);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (unsigned(num_keys) > args.size() - 3) {
|
||||
(*cntx)->SendError("Number of keys can't be greater than number of args");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Service::Service(ProactorPool* pp) : pp_(*pp), shard_set_(pp), server_family_(this) {
|
||||
|
@ -365,11 +394,23 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
|
|||
return (*cntx)->SendError(WrongNumArgsError(cmd_str));
|
||||
}
|
||||
|
||||
if (under_multi && (cid->opt_mask() & CO::ADMIN)) {
|
||||
(*cntx)->SendError("Can not run admin commands under multi-transactions");
|
||||
// Validate more complicated cases with custom validators.
|
||||
if (!cid->Validate(args, cntx)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (under_multi) {
|
||||
if (cid->opt_mask() & CO::ADMIN) {
|
||||
(*cntx)->SendError("Can not run admin commands under transactions");
|
||||
return;
|
||||
}
|
||||
|
||||
if (string_view{cid->name()} == "SELECT") {
|
||||
(*cntx)->SendError("Can not call SELECT within a transaction");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
std::move(multi_error).Cancel();
|
||||
|
||||
if (cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE && !is_trans_cmd) {
|
||||
|
@ -389,16 +430,18 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
|
|||
// Create command transaction
|
||||
intrusive_ptr<Transaction> dist_trans;
|
||||
|
||||
if (cid->first_key_pos() > 0 || (cid->opt_mask() & CO::GLOBAL_TRANS)) {
|
||||
dist_trans.reset(new Transaction{cid, &shard_set_});
|
||||
cntx->transaction = dist_trans.get();
|
||||
if (!under_script) {
|
||||
DCHECK(cntx->transaction == nullptr);
|
||||
|
||||
if (IsTransactional(cid)) {
|
||||
dist_trans.reset(new Transaction{cid, &shard_set_});
|
||||
cntx->transaction = dist_trans.get();
|
||||
|
||||
if (cid->first_key_pos() > 0) {
|
||||
dist_trans->InitByArgs(cntx->conn_state.db_index, args);
|
||||
cntx->last_command_debug.shards_count = cntx->transaction->unique_shard_cnt();
|
||||
} else {
|
||||
cntx->transaction = nullptr;
|
||||
}
|
||||
} else {
|
||||
cntx->transaction = nullptr;
|
||||
}
|
||||
|
||||
cntx->cid = cid;
|
||||
|
@ -411,7 +454,10 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
|
|||
cntx->last_command_debug.clock = dist_trans->txid();
|
||||
cntx->last_command_debug.is_ooo = dist_trans->IsOOO();
|
||||
}
|
||||
cntx->transaction = nullptr;
|
||||
|
||||
if (!under_script) {
|
||||
cntx->transaction = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view value,
|
||||
|
@ -508,16 +554,10 @@ void Service::CallFromScript(CmdArgList args, ObjectExplorer* reply, ConnectionC
|
|||
}
|
||||
|
||||
void Service::Eval(CmdArgList args, ConnectionContext* cntx) {
|
||||
string_view num_keys_str = ArgS(args, 2);
|
||||
int32_t num_keys;
|
||||
uint32_t num_keys;
|
||||
|
||||
if (!absl::SimpleAtoi(num_keys_str, &num_keys) || num_keys < 0) {
|
||||
return (*cntx)->SendError(kInvalidIntErr);
|
||||
}
|
||||
CHECK(absl::SimpleAtoi(ArgS(args, 2), &num_keys)); // we already validated this
|
||||
|
||||
if (unsigned(num_keys) > args.size() - 3) {
|
||||
return (*cntx)->SendError("Number of keys can't be greater than number of args");
|
||||
}
|
||||
string_view body = ArgS(args, 1);
|
||||
body = absl::StripAsciiWhitespace(body);
|
||||
|
||||
|
@ -547,15 +587,9 @@ void Service::Eval(CmdArgList args, ConnectionContext* cntx) {
|
|||
|
||||
void Service::EvalSha(CmdArgList args, ConnectionContext* cntx) {
|
||||
string_view num_keys_str = ArgS(args, 2);
|
||||
int32_t num_keys;
|
||||
uint32_t num_keys;
|
||||
|
||||
if (!absl::SimpleAtoi(num_keys_str, &num_keys) || num_keys < 0) {
|
||||
return (*cntx)->SendError(kInvalidIntErr);
|
||||
}
|
||||
|
||||
if (unsigned(num_keys) > args.size() - 3) {
|
||||
return (*cntx)->SendError("Number of keys can't be greater than number of args");
|
||||
}
|
||||
CHECK(absl::SimpleAtoi(num_keys_str, &num_keys));
|
||||
|
||||
ToLower(&args[1]);
|
||||
|
||||
|
@ -708,8 +742,8 @@ void Service::RegisterCommands() {
|
|||
|
||||
registry_ << CI{"QUIT", CO::READONLY | CO::FAST, 1, 0, 0, 0}.HFUNC(Quit)
|
||||
<< CI{"MULTI", CO::NOSCRIPT | CO::FAST | CO::LOADING, 1, 0, 0, 0}.HFUNC(Multi)
|
||||
<< CI{"EVAL", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(Eval)
|
||||
<< CI{"EVALSHA", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(EvalSha)
|
||||
<< CI{"EVAL", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(Eval).SetValidator(&EvalValidator)
|
||||
<< CI{"EVALSHA", CO::NOSCRIPT, -3, 0, 0, 0}.MFUNC(EvalSha).SetValidator(&EvalValidator)
|
||||
<< CI{"EXEC", kExecMask, 1, 0, 0, 0}.MFUNC(Exec);
|
||||
|
||||
StringFamily::Register(®istry_);
|
||||
|
|
|
@ -161,8 +161,12 @@ RespVec BaseFamilyTest::Run(initializer_list<std::string_view> list) {
|
|||
return pp_->at(0)->Await([&] { return this->Run(list); });
|
||||
}
|
||||
|
||||
mu_.lock();
|
||||
string id = GetId();
|
||||
return Run(id, list);
|
||||
}
|
||||
|
||||
RespVec BaseFamilyTest::Run(std::string_view id, std::initializer_list<std::string_view> list) {
|
||||
mu_.lock();
|
||||
auto [it, inserted] = connections_.emplace(id, nullptr);
|
||||
|
||||
if (inserted) {
|
||||
|
@ -178,8 +182,12 @@ RespVec BaseFamilyTest::Run(initializer_list<std::string_view> list) {
|
|||
auto& context = conn->cmd_cntx;
|
||||
context.shard_set = ess_;
|
||||
|
||||
DCHECK(context.transaction == nullptr);
|
||||
|
||||
service_->DispatchCommand(cmd_arg_list, &context);
|
||||
|
||||
DCHECK(context.transaction == nullptr);
|
||||
|
||||
unique_lock lk(mu_);
|
||||
last_cmd_dbg_info_ = context.last_command_debug;
|
||||
|
||||
|
|
|
@ -96,7 +96,10 @@ class BaseFamilyTest : public ::testing::Test {
|
|||
void TearDown() override;
|
||||
|
||||
protected:
|
||||
|
||||
RespVec Run(std::initializer_list<std::string_view> list);
|
||||
RespVec Run(std::string_view id, std::initializer_list<std::string_view> list);
|
||||
|
||||
int64_t CheckedInt(std::initializer_list<std::string_view> list);
|
||||
|
||||
bool IsLocked(DbIndex db_index, std::string_view key) const;
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
|
||||
#include "server/transaction.h"
|
||||
|
||||
#include <absl/strings/match.h>
|
||||
|
||||
#include "base/logging.h"
|
||||
#include "server/command_registry.h"
|
||||
#include "server/db_slice.h"
|
||||
|
@ -22,6 +24,43 @@ std::atomic_uint64_t op_seq{1};
|
|||
|
||||
[[maybe_unused]] constexpr size_t kTransSize = sizeof(Transaction);
|
||||
|
||||
struct KeyIndex {
|
||||
unsigned start;
|
||||
unsigned end; // not including
|
||||
unsigned step;
|
||||
};
|
||||
|
||||
KeyIndex DetermineKeys(const CommandId* cid, const CmdArgList& args) {
|
||||
DCHECK_EQ(0u, cid->opt_mask() & CO::GLOBAL_TRANS);
|
||||
|
||||
KeyIndex key_index;
|
||||
|
||||
if (cid->first_key_pos() > 0) {
|
||||
key_index.start = cid->first_key_pos();
|
||||
int last = cid->last_key_pos();
|
||||
key_index.end = last > 0 ? last + 1 : (int(args.size()) + 1 + last);
|
||||
key_index.step = cid->key_arg_step();
|
||||
|
||||
return key_index;
|
||||
}
|
||||
|
||||
string_view name{cid->name()};
|
||||
if (name == "EVAL" || name == "EVALSHA") {
|
||||
DCHECK_GE(args.size(), 3u);
|
||||
uint32_t num_keys;
|
||||
|
||||
CHECK(absl::SimpleAtoi(ArgS(args, 2), &num_keys));
|
||||
key_index.start = 3;
|
||||
key_index.end = 3 + num_keys;
|
||||
key_index.step = 1;
|
||||
|
||||
return key_index;
|
||||
}
|
||||
LOG(FATAL) << "Not supported";
|
||||
|
||||
return key_index;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
struct Transaction::FindFirstProcessor {
|
||||
|
@ -106,18 +145,6 @@ Transaction::Transaction(const CommandId* cid, EngineShardSet* ess) : cid_(cid),
|
|||
multi_.reset(new Multi);
|
||||
}
|
||||
trans_options_ = cid_->opt_mask();
|
||||
|
||||
bool single_key = cid_->first_key_pos() > 0 && !cid_->is_multi_key();
|
||||
if (!multi_ && single_key) {
|
||||
shard_data_.resize(1); // Single key optimization
|
||||
} else {
|
||||
// Our shard_data is not sparse, so we must allocate for all threads :(
|
||||
shard_data_.resize(ess_->size());
|
||||
}
|
||||
|
||||
if (IsGlobal()) {
|
||||
unique_shard_cnt_ = ess->size();
|
||||
}
|
||||
}
|
||||
|
||||
Transaction::~Transaction() {
|
||||
|
@ -130,9 +157,9 @@ Transaction::~Transaction() {
|
|||
* a. T spans a single shard and its not multi.
|
||||
* unique_shard_id_ is predefined before the schedule() is called.
|
||||
* In that case only a single thread will be scheduled and it will use shard_data[0] just becase
|
||||
* shard_data.size() = 1. Engine thread can access any data because there is schedule barrier
|
||||
* between InitByArgs and RunInShard/IsArmedInShard functions.
|
||||
* b. T spans multiple shards and its not multi
|
||||
* shard_data.size() = 1. Coordinator thread can access any data because there is a
|
||||
* schedule barrier between InitByArgs and RunInShard/IsArmedInShard functions.
|
||||
* b. T spans multiple shards and its not multi
|
||||
* In that case multiple threads will be scheduled. Similarly they have a schedule barrier,
|
||||
* and IsArmedInShard can read any variable from shard_data[x].
|
||||
* c. Trans spans a single shard and it's multi. shard_data has size of ess_.size.
|
||||
|
@ -145,24 +172,46 @@ Transaction::~Transaction() {
|
|||
**/
|
||||
|
||||
void Transaction::InitByArgs(DbIndex index, CmdArgList args) {
|
||||
CHECK_GT(args.size(), 1U);
|
||||
CHECK_LT(size_t(cid_->first_key_pos()), args.size());
|
||||
DCHECK_EQ(unique_shard_cnt_, 0u);
|
||||
DCHECK(!IsGlobal()) << "Global transactions do not have keys";
|
||||
|
||||
db_index_ = index;
|
||||
|
||||
if (!multi_ && !cid_->is_multi_key()) { // Single key optimization.
|
||||
auto key = ArgS(args, cid_->first_key_pos());
|
||||
if (IsGlobal()) {
|
||||
unique_shard_cnt_ = ess_->size();
|
||||
shard_data_.resize(unique_shard_cnt_);
|
||||
return;
|
||||
}
|
||||
|
||||
CHECK_GT(args.size(), 1U); // first entry is the command name.
|
||||
DCHECK_EQ(unique_shard_cnt_, 0u);
|
||||
|
||||
KeyIndex key_index = DetermineKeys(cid_, args);
|
||||
|
||||
if (key_index.start == args.size()) {
|
||||
CHECK(absl::StartsWith(cid_->name(), "EVAL"));
|
||||
return;
|
||||
}
|
||||
|
||||
DCHECK_LT(key_index.start, args.size());
|
||||
|
||||
bool single_key = key_index.start > 0 && !cid_->is_multi_key();
|
||||
if (!multi_ && single_key) {
|
||||
shard_data_.resize(1); // Single key optimization
|
||||
} else {
|
||||
// Our shard_data is not sparse, so we must allocate for all threads :(
|
||||
shard_data_.resize(ess_->size());
|
||||
}
|
||||
|
||||
if (!multi_ && single_key) { // Single key optimization.
|
||||
auto key = ArgS(args, key_index.start);
|
||||
args_.push_back(key);
|
||||
|
||||
unique_shard_cnt_ = 1;
|
||||
unique_shard_id_ = Shard(key, ess_->size());
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
CHECK(cid_->key_arg_step() == 1 || cid_->key_arg_step() == 2);
|
||||
DCHECK(cid_->key_arg_step() == 1 || (args.size() % 2) == 1);
|
||||
CHECK(key_index.step == 1 || key_index.step == 2);
|
||||
DCHECK(key_index.step == 1 || (args.size() % 2) == 1);
|
||||
|
||||
// Reuse thread-local temporary storage. Since this code is non-preemptive we can use it here.
|
||||
auto& shard_index = tmp_space.shard_cache;
|
||||
|
@ -171,6 +220,8 @@ void Transaction::InitByArgs(DbIndex index, CmdArgList args) {
|
|||
v.Clear();
|
||||
}
|
||||
|
||||
// TODO: to determine correctly locking mode for transactions, scripts
|
||||
// and regular commands.
|
||||
IntentLock::Mode mode = IntentLock::EXCLUSIVE;
|
||||
if (multi_) {
|
||||
mode = Mode();
|
||||
|
@ -178,10 +229,8 @@ void Transaction::InitByArgs(DbIndex index, CmdArgList args) {
|
|||
DCHECK_LT(int(mode), 2);
|
||||
}
|
||||
|
||||
size_t key_end = cid_->last_key_pos() > 0 ? cid_->last_key_pos() + 1
|
||||
: (args.size() + 1 + cid_->last_key_pos());
|
||||
for (size_t i = 1; i < key_end; ++i) {
|
||||
std::string_view key = ArgS(args, i);
|
||||
for (unsigned i = key_index.start; i < key_index.end; ++i) {
|
||||
string_view key = ArgS(args, i);
|
||||
uint32_t sid = Shard(key, shard_data_.size());
|
||||
shard_index[sid].args.push_back(key);
|
||||
shard_index[sid].original_index.push_back(i - 1);
|
||||
|
@ -190,7 +239,7 @@ void Transaction::InitByArgs(DbIndex index, CmdArgList args) {
|
|||
multi_->locks[key].cnt[int(mode)]++;
|
||||
};
|
||||
|
||||
if (cid_->key_arg_step() == 2) { // value
|
||||
if (key_index.step == 2) { // value
|
||||
++i;
|
||||
auto val = ArgS(args, i);
|
||||
shard_index[sid].args.push_back(val);
|
||||
|
@ -198,7 +247,7 @@ void Transaction::InitByArgs(DbIndex index, CmdArgList args) {
|
|||
}
|
||||
}
|
||||
|
||||
args_.resize(key_end - 1);
|
||||
args_.resize(key_index.end - key_index.start);
|
||||
reverse_index_.resize(args_.size());
|
||||
|
||||
auto next_arg = args_.begin();
|
||||
|
@ -209,7 +258,9 @@ void Transaction::InitByArgs(DbIndex index, CmdArgList args) {
|
|||
for (size_t i = 0; i < shard_data_.size(); ++i) {
|
||||
auto& sd = shard_data_[i];
|
||||
auto& si = shard_index[i];
|
||||
|
||||
CHECK_LT(si.args.size(), 1u << 15);
|
||||
|
||||
sd.arg_count = si.args.size();
|
||||
sd.arg_start = next_arg - args_.begin();
|
||||
sd.local_mask = 0;
|
||||
|
|
Loading…
Reference in New Issue