Dragonfly Dispatch is called from lua script

This commit is contained in:
Roman Gershman 2022-02-05 21:04:32 +02:00
parent 07df3f2b95
commit cc53bde091
4 changed files with 166 additions and 7 deletions

View File

@ -90,6 +90,12 @@ class ConnectionContext {
return rbuilder_.get(); return rbuilder_.get();
} }
ReplyBuilderInterface* Inject(ReplyBuilderInterface* new_i) {
ReplyBuilderInterface* res = rbuilder_.release();
rbuilder_.reset(new_i);
return res;
}
private: private:
Connection* owner_; Connection* owner_;
std::unique_ptr<ReplyBuilderInterface> rbuilder_; std::unique_ptr<ReplyBuilderInterface> rbuilder_;

View File

@ -218,8 +218,11 @@ TEST_F(DflyEngineTest, Eval) {
auto resp = Run({"eval", "return 43", "0"}); auto resp = Run({"eval", "return 43", "0"});
EXPECT_THAT(resp[0], IntArg(43)); EXPECT_THAT(resp[0], IntArg(43));
// resp = Run({"eval", "return redis.call('get', 'foo')", "0"}); resp = Run({"incrby", "foo", "42"});
// EXPECT_THAT(resp[0], IntArg(42)); // TODO. EXPECT_THAT(resp[0], IntArg(42));
resp = Run({"eval", "return redis.call('get', 'foo')", "0"});
EXPECT_THAT(resp[0], StrArg("42"));
} }
TEST_F(DflyEngineTest, EvalSha) { TEST_F(DflyEngineTest, EvalSha) {

View File

@ -50,6 +50,39 @@ metrics::CounterFamily cmd_req("requests_total", "Number of served redis request
constexpr size_t kMaxThreadSize = 1024; constexpr size_t kMaxThreadSize = 1024;
class InterpreterReplier : public RedisReplyBuilder {
public:
InterpreterReplier(ObjectExplorer* explr) : RedisReplyBuilder(nullptr), explr_(explr) {
}
void SendError(std::string_view str) override;
void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) override;
void SendGetNotFound() override;
void SendStored() override;
void SendSimpleString(std::string_view str) final;
void SendMGetResponse(const StrOrNil* arr, uint32_t count) final;
void SendSimpleStrArr(const std::string_view* arr, uint32_t count) final;
void SendNullArray() final;
void SendStringArr(absl::Span<const std::string_view> arr) final;
void SendNull() final;
void SendLong(long val) final;
void SendDouble(double val) final;
void SendBulkString(std::string_view str) final;
void StartArray(unsigned len) final;
private:
void PostItem();
ObjectExplorer* explr_;
vector<pair<unsigned, unsigned>> array_len_;
unsigned num_elems_ = 0;
};
class EvalSerializer : public ObjectExplorer { class EvalSerializer : public ObjectExplorer {
public: public:
EvalSerializer(RedisReplyBuilder* rb) : rb_(rb) { EvalSerializer(RedisReplyBuilder* rb) : rb_(rb) {
@ -63,7 +96,7 @@ class EvalSerializer : public ObjectExplorer {
} }
} }
void OnString(std::string_view str) final { void OnString(string_view str) final {
rb_->SendBulkString(str); rb_->SendBulkString(str);
} }
@ -87,11 +120,11 @@ class EvalSerializer : public ObjectExplorer {
rb_->SendNull(); rb_->SendNull();
} }
void OnStatus(std::string_view str) { void OnStatus(string_view str) {
rb_->SendSimpleString(str); rb_->SendSimpleString(str);
} }
void OnError(std::string_view str) { void OnError(string_view str) {
rb_->SendError(str); rb_->SendError(str);
} }
@ -99,6 +132,117 @@ class EvalSerializer : public ObjectExplorer {
RedisReplyBuilder* rb_; RedisReplyBuilder* rb_;
}; };
void InterpreterReplier::PostItem() {
if (array_len_.empty()) {
DCHECK_EQ(0u, num_elems_);
++num_elems_;
} else {
++num_elems_;
while (num_elems_ == array_len_.back().second) {
num_elems_ = array_len_.back().first;
explr_->OnArrayEnd();
array_len_.pop_back();
if (array_len_.empty())
break;
}
}
}
void InterpreterReplier::SendError(string_view str) {
DCHECK(array_len_.empty());
explr_->OnError(str);
}
void InterpreterReplier::SendGetReply(string_view key, uint32_t flags, string_view value) {
DCHECK(array_len_.empty());
explr_->OnString(value);
}
void InterpreterReplier::SendGetNotFound() {
DCHECK(array_len_.empty());
explr_->OnNil();
}
void InterpreterReplier::SendStored() {
DCHECK(array_len_.empty());
SendSimpleString("OK");
}
void InterpreterReplier::SendSimpleString(string_view str) {
if (array_len_.empty())
explr_->OnStatus(str);
else
explr_->OnString(str);
PostItem();
}
void InterpreterReplier::SendMGetResponse(const StrOrNil* arr, uint32_t count) {
DCHECK(array_len_.empty());
explr_->OnArrayStart(count);
for (uint32_t i = 0; i < count; ++i) {
if (arr[i].has_value()) {
explr_->OnString(*arr[i]);
} else {
explr_->OnNil();
}
}
explr_->OnArrayEnd();
}
void InterpreterReplier::SendSimpleStrArr(const string_view* arr, uint32_t count) {
explr_->OnArrayStart(count);
for (uint32_t i = 0; i < count; ++i) {
explr_->OnString(arr[i]);
}
explr_->OnArrayEnd();
PostItem();
}
void InterpreterReplier::SendNullArray() {
SendSimpleStrArr(nullptr, 0);
PostItem();
}
void InterpreterReplier::SendStringArr(absl::Span<const string_view> arr) {
SendSimpleStrArr(arr.data(), arr.size());
PostItem();
}
void InterpreterReplier::SendNull() {
explr_->OnNil();
PostItem();
}
void InterpreterReplier::SendLong(long val) {
explr_->OnInt(val);
PostItem();
}
void InterpreterReplier::SendDouble(double val) {
explr_->OnDouble(val);
PostItem();
}
void InterpreterReplier::SendBulkString(string_view str) {
explr_->OnString(str);
PostItem();
}
void InterpreterReplier::StartArray(unsigned len) {
explr_->OnArrayStart(len);
if (len == 0) {
explr_->OnArrayEnd();
PostItem();
} else {
array_len_.emplace_back(num_elems_ + 1, len);
num_elems_ = 0;
}
}
bool IsSHA(string_view str) { bool IsSHA(string_view str) {
for (auto c : str) { for (auto c : str) {
if (!absl::ascii_isxdigit(c)) if (!absl::ascii_isxdigit(c))
@ -355,8 +499,12 @@ void Service::Multi(CmdArgList args, ConnectionContext* cntx) {
} }
void Service::CallFromScript(CmdArgList args, ObjectExplorer* reply, ConnectionContext* cntx) { void Service::CallFromScript(CmdArgList args, ObjectExplorer* reply, ConnectionContext* cntx) {
// reply->OnInt(42); InterpreterReplier replier(reply);
ReplyBuilderInterface* orig = cntx->Inject(&replier);
DispatchCommand(std::move(args), cntx); DispatchCommand(std::move(args), cntx);
cntx->Inject(orig);
} }
void Service::Eval(CmdArgList args, ConnectionContext* cntx) { void Service::Eval(CmdArgList args, ConnectionContext* cntx) {

View File

@ -36,6 +36,8 @@ void SinkReplyBuilder::CloseConnection() {
} }
void SinkReplyBuilder::Send(const iovec* v, uint32_t len) { void SinkReplyBuilder::Send(const iovec* v, uint32_t len) {
DCHECK(sink_);
if (should_batch_) { if (should_batch_) {
// TODO: to introduce flushing when too much data is batched. // TODO: to introduce flushing when too much data is batched.
for (unsigned i = 0; i < len; ++i) { for (unsigned i = 0; i < len; ++i) {