diff --git a/server/conn_context.h b/server/conn_context.h index d832bbc..39e59cb 100644 --- a/server/conn_context.h +++ b/server/conn_context.h @@ -90,6 +90,12 @@ class ConnectionContext { return rbuilder_.get(); } + ReplyBuilderInterface* Inject(ReplyBuilderInterface* new_i) { + ReplyBuilderInterface* res = rbuilder_.release(); + rbuilder_.reset(new_i); + return res; + } + private: Connection* owner_; std::unique_ptr rbuilder_; diff --git a/server/dragonfly_test.cc b/server/dragonfly_test.cc index 0cb8858..6da3681 100644 --- a/server/dragonfly_test.cc +++ b/server/dragonfly_test.cc @@ -218,8 +218,11 @@ TEST_F(DflyEngineTest, Eval) { auto resp = Run({"eval", "return 43", "0"}); EXPECT_THAT(resp[0], IntArg(43)); - // resp = Run({"eval", "return redis.call('get', 'foo')", "0"}); - // EXPECT_THAT(resp[0], IntArg(42)); // TODO. + resp = Run({"incrby", "foo", "42"}); + EXPECT_THAT(resp[0], IntArg(42)); + + resp = Run({"eval", "return redis.call('get', 'foo')", "0"}); + EXPECT_THAT(resp[0], StrArg("42")); } TEST_F(DflyEngineTest, EvalSha) { diff --git a/server/main_service.cc b/server/main_service.cc index 7cd7634..4244884 100644 --- a/server/main_service.cc +++ b/server/main_service.cc @@ -50,6 +50,39 @@ metrics::CounterFamily cmd_req("requests_total", "Number of served redis request constexpr size_t kMaxThreadSize = 1024; +class InterpreterReplier : public RedisReplyBuilder { + public: + InterpreterReplier(ObjectExplorer* explr) : RedisReplyBuilder(nullptr), explr_(explr) { + } + + void SendError(std::string_view str) override; + void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) override; + void SendGetNotFound() override; + void SendStored() override; + + void SendSimpleString(std::string_view str) final; + void SendMGetResponse(const StrOrNil* arr, uint32_t count) final; + void SendSimpleStrArr(const std::string_view* arr, uint32_t count) final; + void SendNullArray() final; + + void SendStringArr(absl::Span arr) final; + void SendNull() final; + + void SendLong(long val) final; + void SendDouble(double val) final; + + void SendBulkString(std::string_view str) final; + + void StartArray(unsigned len) final; + + private: + void PostItem(); + + ObjectExplorer* explr_; + vector> array_len_; + unsigned num_elems_ = 0; +}; + class EvalSerializer : public ObjectExplorer { public: EvalSerializer(RedisReplyBuilder* rb) : rb_(rb) { @@ -63,7 +96,7 @@ class EvalSerializer : public ObjectExplorer { } } - void OnString(std::string_view str) final { + void OnString(string_view str) final { rb_->SendBulkString(str); } @@ -87,11 +120,11 @@ class EvalSerializer : public ObjectExplorer { rb_->SendNull(); } - void OnStatus(std::string_view str) { + void OnStatus(string_view str) { rb_->SendSimpleString(str); } - void OnError(std::string_view str) { + void OnError(string_view str) { rb_->SendError(str); } @@ -99,6 +132,117 @@ class EvalSerializer : public ObjectExplorer { RedisReplyBuilder* rb_; }; +void InterpreterReplier::PostItem() { + if (array_len_.empty()) { + DCHECK_EQ(0u, num_elems_); + ++num_elems_; + } else { + ++num_elems_; + + while (num_elems_ == array_len_.back().second) { + num_elems_ = array_len_.back().first; + explr_->OnArrayEnd(); + + array_len_.pop_back(); + if (array_len_.empty()) + break; + } + } +} + +void InterpreterReplier::SendError(string_view str) { + DCHECK(array_len_.empty()); + explr_->OnError(str); +} + +void InterpreterReplier::SendGetReply(string_view key, uint32_t flags, string_view value) { + DCHECK(array_len_.empty()); + explr_->OnString(value); +} + +void InterpreterReplier::SendGetNotFound() { + DCHECK(array_len_.empty()); + explr_->OnNil(); +} + +void InterpreterReplier::SendStored() { + DCHECK(array_len_.empty()); + SendSimpleString("OK"); +} + +void InterpreterReplier::SendSimpleString(string_view str) { + if (array_len_.empty()) + explr_->OnStatus(str); + else + explr_->OnString(str); + PostItem(); +} + +void InterpreterReplier::SendMGetResponse(const StrOrNil* arr, uint32_t count) { + DCHECK(array_len_.empty()); + + explr_->OnArrayStart(count); + for (uint32_t i = 0; i < count; ++i) { + if (arr[i].has_value()) { + explr_->OnString(*arr[i]); + } else { + explr_->OnNil(); + } + } + explr_->OnArrayEnd(); +} + +void InterpreterReplier::SendSimpleStrArr(const string_view* arr, uint32_t count) { + explr_->OnArrayStart(count); + for (uint32_t i = 0; i < count; ++i) { + explr_->OnString(arr[i]); + } + explr_->OnArrayEnd(); + PostItem(); +} + +void InterpreterReplier::SendNullArray() { + SendSimpleStrArr(nullptr, 0); + PostItem(); +} + +void InterpreterReplier::SendStringArr(absl::Span arr) { + SendSimpleStrArr(arr.data(), arr.size()); + PostItem(); +} + +void InterpreterReplier::SendNull() { + explr_->OnNil(); + PostItem(); +} + +void InterpreterReplier::SendLong(long val) { + explr_->OnInt(val); + PostItem(); +} + +void InterpreterReplier::SendDouble(double val) { + explr_->OnDouble(val); + PostItem(); +} + +void InterpreterReplier::SendBulkString(string_view str) { + explr_->OnString(str); + PostItem(); +} + +void InterpreterReplier::StartArray(unsigned len) { + explr_->OnArrayStart(len); + + if (len == 0) { + explr_->OnArrayEnd(); + PostItem(); + } else { + array_len_.emplace_back(num_elems_ + 1, len); + num_elems_ = 0; + } +} + bool IsSHA(string_view str) { for (auto c : str) { if (!absl::ascii_isxdigit(c)) @@ -355,8 +499,12 @@ void Service::Multi(CmdArgList args, ConnectionContext* cntx) { } void Service::CallFromScript(CmdArgList args, ObjectExplorer* reply, ConnectionContext* cntx) { - // reply->OnInt(42); + InterpreterReplier replier(reply); + ReplyBuilderInterface* orig = cntx->Inject(&replier); + DispatchCommand(std::move(args), cntx); + + cntx->Inject(orig); } void Service::Eval(CmdArgList args, ConnectionContext* cntx) { @@ -475,7 +623,7 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter, Interpreter::RunResult result = interpreter->RunFunction(eval_args.sha, &error); - cntx->conn_state.script_info.reset(); // reset script_info + cntx->conn_state.script_info.reset(); // reset script_info if (result == Interpreter::RUN_ERR) { string resp = absl::StrCat("Error running script (call to ", eval_args.sha, "): ", error); diff --git a/server/reply_builder.cc b/server/reply_builder.cc index 37cf972..ffc8e00 100644 --- a/server/reply_builder.cc +++ b/server/reply_builder.cc @@ -36,6 +36,8 @@ void SinkReplyBuilder::CloseConnection() { } void SinkReplyBuilder::Send(const iovec* v, uint32_t len) { + DCHECK(sink_); + if (should_batch_) { // TODO: to introduce flushing when too much data is batched. for (unsigned i = 0; i < len; ++i) {