Implement serialization of most lua data types that are returned to caller

This commit is contained in:
Roman Gershman 2022-02-01 13:15:19 +02:00
parent 8fbe19d3c5
commit d1f6f6d410
6 changed files with 346 additions and 26 deletions

View File

@ -8,6 +8,7 @@
#include <openssl/sha.h>
#include <cstring>
#include <optional>
extern "C" {
#include <lauxlib.h>
@ -100,6 +101,20 @@ void ToHex(const uint8_t* src, char* dest) {
dest[40] = '\0';
}
string_view TopSv(lua_State* lua) {
return string_view{lua_tostring(lua, -1), lua_rawlen(lua, -1)};
}
optional<int> FetchKey(lua_State* lua, const char* key) {
lua_pushstring(lua, key);
int type = lua_gettable(lua, -2);
if (type == LUA_TNIL) {
lua_pop(lua, 1);
return nullopt;
}
return type;
}
} // namespace
Interpreter::Interpreter() {
@ -127,24 +142,7 @@ bool Interpreter::AddFunction(string_view body, string* result) {
char funcname[43];
Fingerprint(body, funcname);
string script = absl::StrCat("function ", funcname, "() \n");
absl::StrAppend(&script, body, "\nend");
int res = luaL_loadbuffer(lua_, script.data(), script.size(), "@user_script");
if (res == 0) {
res = lua_pcall(lua_, 0, 0, 0); // run func definition code
}
if (res) {
result->assign(lua_tostring(lua_, -1));
lua_pop(lua_, 1); // Remove the error.
return false;
}
result->assign(funcname);
return true;
return AddInternal(funcname, body, result);
}
bool Interpreter::RunFunction(const char* f_id, std::string* error) {
@ -167,4 +165,118 @@ bool Interpreter::RunFunction(const char* f_id, std::string* error) {
return err == 0;
}
bool Interpreter::Execute(string_view body, char f_id[43], string* error) {
lua_getglobal(lua_, "__redis__err__handler");
Fingerprint(body, f_id);
int type = lua_getglobal(lua_, f_id);
if (type != LUA_TFUNCTION) {
lua_pop(lua_, 1);
if (!AddInternal(f_id, body, error))
return false;
type = lua_getglobal(lua_, f_id);
CHECK_EQ(type, LUA_TFUNCTION);
}
int err = lua_pcall(lua_, 0, 1, -2);
if (err) {
*error = lua_tostring(lua_, -1);
}
return err == 0;
}
bool Interpreter::AddInternal(const char* f_id, string_view body, string* result) {
string script = absl::StrCat("function ", f_id, "() \n");
absl::StrAppend(&script, body, "\nend");
int res = luaL_loadbuffer(lua_, script.data(), script.size(), "@user_script");
if (res == 0) {
res = lua_pcall(lua_, 0, 0, 0); // run func definition code
}
if (res) {
result->assign(lua_tostring(lua_, -1));
lua_pop(lua_, 1); // Remove the error.
return false;
}
result->assign(f_id);
return true;
}
bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) {
// TODO: to get rid of this check or move it to the external function.
// It does not make sense to do this check recursively and it complicates the flow
// were in the middle of the serialization we could theoretically fail.
if (!lua_checkstack(lua_, 4)) {
/* Increase the Lua stack if needed to make sure there is enough room
* to push 4 elements to the stack. On failure, return error.
* Notice that we need, in the worst case, 4 elements because returning a map might
* require push 4 elements to the Lua stack.*/
error->assign("reached lua stack limit");
lua_pop(lua_, 1); /* pop the element from the stack */
return false;
}
int t = lua_type(lua_, -1);
bool res = true;
switch (t) {
case LUA_TSTRING:
serializer->OnString(TopSv(lua_));
break;
case LUA_TBOOLEAN:
serializer->OnBool(lua_toboolean(lua_, -1));
break;
case LUA_TNUMBER:
if (lua_isinteger(lua_, -1)) {
serializer->OnInt(lua_tointeger(lua_, -1));
} else {
serializer->OnDouble(lua_tonumber(lua_, -1));
}
break;
case LUA_TTABLE: {
unsigned len = lua_rawlen(lua_, -1);
if (len > 0) { // array
serializer->OnArrayStart(len);
for (unsigned i = 0; i < len; ++i) {
t = lua_rawgeti(lua_, -1, i + 1); // push table element
// TODO: we should make sure that we have enough stack space
// to traverse each object. This can be done as a dry-run before doing real serialization.
// Once we are sure we are safe we can simplify the serialization flow and
// remove the error factor.
CHECK(Serialize(serializer, error)); // pops the element
}
serializer->OnArrayEnd();
} else {
auto fres = FetchKey(lua_, "err");
if (fres && *fres == LUA_TSTRING) {
serializer->OnError(TopSv(lua_));
lua_pop(lua_, 1);
break;
}
fres = FetchKey(lua_, "ok");
if (fres && *fres == LUA_TSTRING) {
serializer->OnStatus(TopSv(lua_));
lua_pop(lua_, 1);
break;
}
serializer->OnError("TBD");
}
break;
}
case LUA_TNIL:
serializer->OnNil();
break;
default:
error->assign(absl::StrCat("Unsupported type ", t));
}
lua_pop(lua_, 1);
return res;
}
} // namespace dfly

View File

@ -10,6 +10,21 @@ typedef struct lua_State lua_State;
namespace dfly {
class ObjectExplorer {
public:
virtual ~ObjectExplorer() {}
virtual void OnBool(bool b) = 0;
virtual void OnString(std::string_view str) = 0;
virtual void OnDouble(double d) = 0;
virtual void OnInt(int64_t val) = 0;
virtual void OnArrayStart(unsigned len) = 0;
virtual void OnArrayEnd() = 0;
virtual void OnNil() = 0;
virtual void OnStatus(std::string_view str) = 0;
virtual void OnError(std::string_view str) = 0;
};
class Interpreter {
public:
Interpreter();
@ -32,11 +47,16 @@ class Interpreter {
// Returns: true if the call succeeded, otherwise fills error and returns false.
bool RunFunction(const char* f_id, std::string* err);
bool Execute(std::string_view body, char f_id[43], std::string* err);
bool Serialize(ObjectExplorer* serializer, std::string* err);
// fp must point to buffer with at least 43 chars.
// fp[42] will be set to '\0'.
static void Fingerprint(std::string_view body, char* fp);
private:
bool AddInternal(const char* f_id, std::string_view body, std::string* result);
lua_State* lua_;
};

View File

@ -9,6 +9,7 @@ extern "C" {
#include <lua.h>
}
#include <absl/strings/str_cat.h>
#include <gmock/gmock.h>
#include "base/gtest.h"
@ -17,6 +18,58 @@ extern "C" {
namespace dfly {
using namespace std;
class TestSerializer : public ObjectExplorer {
public:
string res;
void OnBool(bool b) final {
absl::StrAppend(&res, "bool(", b, ") ");
OnItem();
}
void OnString(std::string_view str) final {
absl::StrAppend(&res, "str(", str, ") ");
OnItem();
}
void OnDouble(double d) final {
absl::StrAppend(&res, "d(", d, ") ");
OnItem();
}
void OnInt(int64_t val) final {
absl::StrAppend(&res, "i(", val, ") ");
OnItem();
}
void OnArrayStart(unsigned len) final {
absl::StrAppend(&res, "[");
}
void OnArrayEnd() final {
if (res.back() == ' ')
res.pop_back();
absl::StrAppend(&res, "] ");
}
void OnNil() final {
absl::StrAppend(&res, "nil ");
}
void OnStatus(std::string_view str) {
absl::StrAppend(&res, "status(", str, ") ");
}
void OnError(std::string_view str) {
absl::StrAppend(&res, "err(", str, ") ");
}
private:
void OnItem() {
}
};
class InterpreterTest : public ::testing::Test {
protected:
InterpreterTest() {
@ -31,9 +84,27 @@ class InterpreterTest : public ::testing::Test {
CHECK_EQ(0, lua_pcall(lua(), 0, 0, 0));
}
bool Serialize(string* err) {
ser_.res.clear();
bool res = intptr_.Serialize(&ser_, err);
if (!ser_.res.empty())
ser_.res.pop_back();
return res;
}
bool Execute(string_view script);
Interpreter intptr_;
TestSerializer ser_;
string error_;
};
bool InterpreterTest::Execute(string_view script) {
char buf[48];
return intptr_.Execute(script, buf, &error_) && Serialize(&error_);
}
TEST_F(InterpreterTest, Basic) {
RunInline(R"(
function foo(n)
@ -64,4 +135,35 @@ TEST_F(InterpreterTest, Add) {
EXPECT_EQ(0, lua_gettop(lua()));
}
// Test cases taken from scripting.tcl
TEST_F(InterpreterTest, Execute) {
EXPECT_TRUE(Execute("return 42"));
EXPECT_EQ("i(42)", ser_.res);
EXPECT_TRUE(Execute("return 'hello'"));
EXPECT_EQ("str(hello)", ser_.res);
// Breaks compatibility.
EXPECT_TRUE(Execute("return 100.5"));
EXPECT_EQ("d(100.5)", ser_.res);
EXPECT_TRUE(Execute("return true"));
EXPECT_EQ("bool(1)", ser_.res);
EXPECT_TRUE(Execute("return false"));
EXPECT_EQ("bool(0)", ser_.res);
EXPECT_TRUE(Execute("return {ok='fine'}"));
EXPECT_EQ("status(fine)", ser_.res);
EXPECT_TRUE(Execute("return {err= 'bla'}"));
EXPECT_EQ("err(bla)", ser_.res);
EXPECT_TRUE(Execute("return {1, 2, nil, 3}"));
EXPECT_EQ("[i(1) i(2) nil i(3)]", ser_.res);
EXPECT_TRUE(Execute("return {1,2,3,'ciao',{1,2}}"));
EXPECT_EQ("[i(1) i(2) i(3) str(ciao) [i(1) i(2)]]", ser_.res);
}
} // namespace dfly

View File

@ -213,6 +213,11 @@ TEST_F(DflyEngineTest, FlushDb) {
ASSERT_FALSE(service_->IsShardSetLocked());
}
TEST_F(DflyEngineTest, Eval) {
auto resp = Run({"eval", "return 42", "0"});
EXPECT_THAT(resp[0], IntArg(42));
}
// TODO: to test transactions with a single shard since then all transactions become local.
// To consider having a parameter in dragonfly engine controlling number of shards
// unconditionally from number of cpus. TO TEST BLPOP under multi for single/multi argument case.

View File

@ -49,6 +49,51 @@ metrics::CounterFamily cmd_req("requests_total", "Number of served redis request
constexpr size_t kMaxThreadSize = 1024;
class EvalSerializer : public ObjectExplorer {
public:
EvalSerializer(ReplyBuilder* rb) : rb_(rb) {
}
void OnBool(bool b) final {
LOG(FATAL) << "TBD";
}
void OnString(std::string_view str) final {
LOG(FATAL) << "TBD";
}
void OnDouble(double d) final {
LOG(FATAL) << "TBD";
}
void OnInt(int64_t val) final {
rb_->SendLong(val);
}
void OnArrayStart(unsigned len) final {
LOG(FATAL) << "TBD";
}
void OnArrayEnd() final {
LOG(FATAL) << "TBD";
}
void OnNil() final {
LOG(FATAL) << "TBD";
}
void OnStatus(std::string_view str) {
LOG(FATAL) << "TBD";
}
void OnError(std::string_view str) {
LOG(FATAL) << "TBD";
}
private:
ReplyBuilder* rb_;
};
} // namespace
Service::Service(ProactorPool* pp) : pp_(*pp), shard_set_(pp), server_family_(this) {
@ -98,7 +143,7 @@ void Service::Shutdown() {
request_latency_usec.Shutdown();
ping_qps.Shutdown();
pp_.AwaitFiberOnAll([](ProactorBase* pb) { ServerState::tlocal()->Shutdown(); });
pp_.AwaitFiberOnAll([](ProactorBase* pb) { ServerState::tlocal()->Shutdown(); });
// to shutdown all the runtime components that depend on EngineShard.
server_family_.Shutdown();
@ -290,9 +335,35 @@ void Service::Multi(CmdArgList args, ConnectionContext* cntx) {
}
void Service::Eval(CmdArgList args, ConnectionContext* cntx) {
Interpreter& script = ServerState::tlocal()->GetInterpreter();
script.lua();
return cntx->SendOk();
string_view body = ArgS(args, 1);
string_view num_keys_str = ArgS(args, 2);
int32_t num_keys;
if (!absl::SimpleAtoi(num_keys_str, &num_keys) || num_keys < 0) {
return cntx->SendError(kInvalidIntErr);
}
if (unsigned(num_keys) > args.size() - 3) {
return cntx->SendError("Number of keys can't be greater than number of args");
}
ServerState* ss = ServerState::tlocal();
lock_guard lk(ss->interpreter_mutex);
Interpreter& script = ss->GetInterpreter();
string error;
char f_id[48];
bool success = script.Execute(body, f_id, &error);
if (success) {
EvalSerializer ser{cntx};
string error;
if (!script.Serialize(&ser, &error)) {
cntx->SendError(error);
}
} else {
string resp = absl::StrCat("Error running script (call to ", f_id, "): ", error);
return cntx->SendError(resp);
}
}
void Service::Exec(CmdArgList args, ConnectionContext* cntx) {

View File

@ -4,12 +4,13 @@
#pragma once
#include <boost/fiber/mutex.hpp>
#include <optional>
#include <vector>
#include "core/interpreter.h"
#include "server/common_types.h"
#include "server/global_state.h"
#include "core/interpreter.h"
namespace dfly {
@ -17,7 +18,7 @@ namespace dfly {
// state around engine shards while the former represents coordinator/connection state.
// There may be threads that handle engine shards but not IO, there may be threads that handle IO
// but not engine shards and there can be threads that handle both. This class is present only
// for threads that handle IO and owne coordination fibers.
// for threads that handle IO and manage incoming connections.
class ServerState { // public struct - to allow initialization.
ServerState(const ServerState&) = delete;
void operator=(const ServerState&) = delete;
@ -53,11 +54,20 @@ class ServerState { // public struct - to allow initialization.
return live_transactions_;
}
GlobalState::S gstate() const { return gstate_;}
void set_gstate(GlobalState::S s) { gstate_ = s; }
GlobalState::S gstate() const {
return gstate_;
}
void set_gstate(GlobalState::S s) {
gstate_ = s;
}
Interpreter& GetInterpreter();
// We have interpreter per thread, not per connection.
// Since we might preempt into different fibers when operating on interpreter
// we must lock it until we finish using it per request.
::boost::fibers::mutex interpreter_mutex;
private:
int64_t live_transactions_ = 0;
std::optional<Interpreter> interpreter_;