Implement serialization of most lua data types that are returned to caller
This commit is contained in:
parent
8fbe19d3c5
commit
d1f6f6d410
|
@ -8,6 +8,7 @@
|
|||
#include <openssl/sha.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <optional>
|
||||
|
||||
extern "C" {
|
||||
#include <lauxlib.h>
|
||||
|
@ -100,6 +101,20 @@ void ToHex(const uint8_t* src, char* dest) {
|
|||
dest[40] = '\0';
|
||||
}
|
||||
|
||||
string_view TopSv(lua_State* lua) {
|
||||
return string_view{lua_tostring(lua, -1), lua_rawlen(lua, -1)};
|
||||
}
|
||||
|
||||
optional<int> FetchKey(lua_State* lua, const char* key) {
|
||||
lua_pushstring(lua, key);
|
||||
int type = lua_gettable(lua, -2);
|
||||
if (type == LUA_TNIL) {
|
||||
lua_pop(lua, 1);
|
||||
return nullopt;
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Interpreter::Interpreter() {
|
||||
|
@ -127,24 +142,7 @@ bool Interpreter::AddFunction(string_view body, string* result) {
|
|||
char funcname[43];
|
||||
Fingerprint(body, funcname);
|
||||
|
||||
string script = absl::StrCat("function ", funcname, "() \n");
|
||||
absl::StrAppend(&script, body, "\nend");
|
||||
|
||||
int res = luaL_loadbuffer(lua_, script.data(), script.size(), "@user_script");
|
||||
if (res == 0) {
|
||||
res = lua_pcall(lua_, 0, 0, 0); // run func definition code
|
||||
}
|
||||
|
||||
if (res) {
|
||||
result->assign(lua_tostring(lua_, -1));
|
||||
lua_pop(lua_, 1); // Remove the error.
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
result->assign(funcname);
|
||||
|
||||
return true;
|
||||
return AddInternal(funcname, body, result);
|
||||
}
|
||||
|
||||
bool Interpreter::RunFunction(const char* f_id, std::string* error) {
|
||||
|
@ -167,4 +165,118 @@ bool Interpreter::RunFunction(const char* f_id, std::string* error) {
|
|||
return err == 0;
|
||||
}
|
||||
|
||||
bool Interpreter::Execute(string_view body, char f_id[43], string* error) {
|
||||
lua_getglobal(lua_, "__redis__err__handler");
|
||||
Fingerprint(body, f_id);
|
||||
|
||||
int type = lua_getglobal(lua_, f_id);
|
||||
if (type != LUA_TFUNCTION) {
|
||||
lua_pop(lua_, 1);
|
||||
if (!AddInternal(f_id, body, error))
|
||||
return false;
|
||||
type = lua_getglobal(lua_, f_id);
|
||||
CHECK_EQ(type, LUA_TFUNCTION);
|
||||
}
|
||||
|
||||
int err = lua_pcall(lua_, 0, 1, -2);
|
||||
if (err) {
|
||||
*error = lua_tostring(lua_, -1);
|
||||
}
|
||||
return err == 0;
|
||||
}
|
||||
|
||||
bool Interpreter::AddInternal(const char* f_id, string_view body, string* result) {
|
||||
string script = absl::StrCat("function ", f_id, "() \n");
|
||||
absl::StrAppend(&script, body, "\nend");
|
||||
|
||||
int res = luaL_loadbuffer(lua_, script.data(), script.size(), "@user_script");
|
||||
if (res == 0) {
|
||||
res = lua_pcall(lua_, 0, 0, 0); // run func definition code
|
||||
}
|
||||
|
||||
if (res) {
|
||||
result->assign(lua_tostring(lua_, -1));
|
||||
lua_pop(lua_, 1); // Remove the error.
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
result->assign(f_id);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) {
|
||||
// TODO: to get rid of this check or move it to the external function.
|
||||
// It does not make sense to do this check recursively and it complicates the flow
|
||||
// were in the middle of the serialization we could theoretically fail.
|
||||
if (!lua_checkstack(lua_, 4)) {
|
||||
/* Increase the Lua stack if needed to make sure there is enough room
|
||||
* to push 4 elements to the stack. On failure, return error.
|
||||
* Notice that we need, in the worst case, 4 elements because returning a map might
|
||||
* require push 4 elements to the Lua stack.*/
|
||||
error->assign("reached lua stack limit");
|
||||
lua_pop(lua_, 1); /* pop the element from the stack */
|
||||
return false;
|
||||
}
|
||||
|
||||
int t = lua_type(lua_, -1);
|
||||
bool res = true;
|
||||
|
||||
switch (t) {
|
||||
case LUA_TSTRING:
|
||||
serializer->OnString(TopSv(lua_));
|
||||
break;
|
||||
case LUA_TBOOLEAN:
|
||||
serializer->OnBool(lua_toboolean(lua_, -1));
|
||||
break;
|
||||
case LUA_TNUMBER:
|
||||
if (lua_isinteger(lua_, -1)) {
|
||||
serializer->OnInt(lua_tointeger(lua_, -1));
|
||||
} else {
|
||||
serializer->OnDouble(lua_tonumber(lua_, -1));
|
||||
}
|
||||
break;
|
||||
case LUA_TTABLE: {
|
||||
unsigned len = lua_rawlen(lua_, -1);
|
||||
if (len > 0) { // array
|
||||
serializer->OnArrayStart(len);
|
||||
for (unsigned i = 0; i < len; ++i) {
|
||||
t = lua_rawgeti(lua_, -1, i + 1); // push table element
|
||||
|
||||
// TODO: we should make sure that we have enough stack space
|
||||
// to traverse each object. This can be done as a dry-run before doing real serialization.
|
||||
// Once we are sure we are safe we can simplify the serialization flow and
|
||||
// remove the error factor.
|
||||
CHECK(Serialize(serializer, error)); // pops the element
|
||||
}
|
||||
serializer->OnArrayEnd();
|
||||
} else {
|
||||
auto fres = FetchKey(lua_, "err");
|
||||
if (fres && *fres == LUA_TSTRING) {
|
||||
serializer->OnError(TopSv(lua_));
|
||||
lua_pop(lua_, 1);
|
||||
break;
|
||||
}
|
||||
|
||||
fres = FetchKey(lua_, "ok");
|
||||
if (fres && *fres == LUA_TSTRING) {
|
||||
serializer->OnStatus(TopSv(lua_));
|
||||
lua_pop(lua_, 1);
|
||||
break;
|
||||
}
|
||||
serializer->OnError("TBD");
|
||||
}
|
||||
break;
|
||||
}
|
||||
case LUA_TNIL:
|
||||
serializer->OnNil();
|
||||
break;
|
||||
default:
|
||||
error->assign(absl::StrCat("Unsupported type ", t));
|
||||
}
|
||||
|
||||
lua_pop(lua_, 1);
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace dfly
|
||||
|
|
|
@ -10,6 +10,21 @@ typedef struct lua_State lua_State;
|
|||
|
||||
namespace dfly {
|
||||
|
||||
class ObjectExplorer {
|
||||
public:
|
||||
virtual ~ObjectExplorer() {}
|
||||
|
||||
virtual void OnBool(bool b) = 0;
|
||||
virtual void OnString(std::string_view str) = 0;
|
||||
virtual void OnDouble(double d) = 0;
|
||||
virtual void OnInt(int64_t val) = 0;
|
||||
virtual void OnArrayStart(unsigned len) = 0;
|
||||
virtual void OnArrayEnd() = 0;
|
||||
virtual void OnNil() = 0;
|
||||
virtual void OnStatus(std::string_view str) = 0;
|
||||
virtual void OnError(std::string_view str) = 0;
|
||||
};
|
||||
|
||||
class Interpreter {
|
||||
public:
|
||||
Interpreter();
|
||||
|
@ -32,11 +47,16 @@ class Interpreter {
|
|||
// Returns: true if the call succeeded, otherwise fills error and returns false.
|
||||
bool RunFunction(const char* f_id, std::string* err);
|
||||
|
||||
bool Execute(std::string_view body, char f_id[43], std::string* err);
|
||||
bool Serialize(ObjectExplorer* serializer, std::string* err);
|
||||
|
||||
// fp must point to buffer with at least 43 chars.
|
||||
// fp[42] will be set to '\0'.
|
||||
static void Fingerprint(std::string_view body, char* fp);
|
||||
|
||||
private:
|
||||
bool AddInternal(const char* f_id, std::string_view body, std::string* result);
|
||||
|
||||
lua_State* lua_;
|
||||
};
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ extern "C" {
|
|||
#include <lua.h>
|
||||
}
|
||||
|
||||
#include <absl/strings/str_cat.h>
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
#include "base/gtest.h"
|
||||
|
@ -17,6 +18,58 @@ extern "C" {
|
|||
namespace dfly {
|
||||
using namespace std;
|
||||
|
||||
class TestSerializer : public ObjectExplorer {
|
||||
public:
|
||||
string res;
|
||||
|
||||
void OnBool(bool b) final {
|
||||
absl::StrAppend(&res, "bool(", b, ") ");
|
||||
OnItem();
|
||||
}
|
||||
|
||||
void OnString(std::string_view str) final {
|
||||
absl::StrAppend(&res, "str(", str, ") ");
|
||||
OnItem();
|
||||
}
|
||||
|
||||
void OnDouble(double d) final {
|
||||
absl::StrAppend(&res, "d(", d, ") ");
|
||||
OnItem();
|
||||
}
|
||||
|
||||
void OnInt(int64_t val) final {
|
||||
absl::StrAppend(&res, "i(", val, ") ");
|
||||
OnItem();
|
||||
}
|
||||
|
||||
void OnArrayStart(unsigned len) final {
|
||||
absl::StrAppend(&res, "[");
|
||||
}
|
||||
|
||||
void OnArrayEnd() final {
|
||||
if (res.back() == ' ')
|
||||
res.pop_back();
|
||||
|
||||
absl::StrAppend(&res, "] ");
|
||||
}
|
||||
|
||||
void OnNil() final {
|
||||
absl::StrAppend(&res, "nil ");
|
||||
}
|
||||
|
||||
void OnStatus(std::string_view str) {
|
||||
absl::StrAppend(&res, "status(", str, ") ");
|
||||
}
|
||||
|
||||
void OnError(std::string_view str) {
|
||||
absl::StrAppend(&res, "err(", str, ") ");
|
||||
}
|
||||
|
||||
private:
|
||||
void OnItem() {
|
||||
}
|
||||
};
|
||||
|
||||
class InterpreterTest : public ::testing::Test {
|
||||
protected:
|
||||
InterpreterTest() {
|
||||
|
@ -31,9 +84,27 @@ class InterpreterTest : public ::testing::Test {
|
|||
CHECK_EQ(0, lua_pcall(lua(), 0, 0, 0));
|
||||
}
|
||||
|
||||
bool Serialize(string* err) {
|
||||
ser_.res.clear();
|
||||
bool res = intptr_.Serialize(&ser_, err);
|
||||
if (!ser_.res.empty())
|
||||
ser_.res.pop_back();
|
||||
return res;
|
||||
}
|
||||
|
||||
bool Execute(string_view script);
|
||||
|
||||
Interpreter intptr_;
|
||||
TestSerializer ser_;
|
||||
string error_;
|
||||
};
|
||||
|
||||
bool InterpreterTest::Execute(string_view script) {
|
||||
char buf[48];
|
||||
|
||||
return intptr_.Execute(script, buf, &error_) && Serialize(&error_);
|
||||
}
|
||||
|
||||
TEST_F(InterpreterTest, Basic) {
|
||||
RunInline(R"(
|
||||
function foo(n)
|
||||
|
@ -64,4 +135,35 @@ TEST_F(InterpreterTest, Add) {
|
|||
EXPECT_EQ(0, lua_gettop(lua()));
|
||||
}
|
||||
|
||||
// Test cases taken from scripting.tcl
|
||||
TEST_F(InterpreterTest, Execute) {
|
||||
EXPECT_TRUE(Execute("return 42"));
|
||||
EXPECT_EQ("i(42)", ser_.res);
|
||||
|
||||
EXPECT_TRUE(Execute("return 'hello'"));
|
||||
EXPECT_EQ("str(hello)", ser_.res);
|
||||
|
||||
// Breaks compatibility.
|
||||
EXPECT_TRUE(Execute("return 100.5"));
|
||||
EXPECT_EQ("d(100.5)", ser_.res);
|
||||
|
||||
EXPECT_TRUE(Execute("return true"));
|
||||
EXPECT_EQ("bool(1)", ser_.res);
|
||||
|
||||
EXPECT_TRUE(Execute("return false"));
|
||||
EXPECT_EQ("bool(0)", ser_.res);
|
||||
|
||||
EXPECT_TRUE(Execute("return {ok='fine'}"));
|
||||
EXPECT_EQ("status(fine)", ser_.res);
|
||||
|
||||
EXPECT_TRUE(Execute("return {err= 'bla'}"));
|
||||
EXPECT_EQ("err(bla)", ser_.res);
|
||||
|
||||
EXPECT_TRUE(Execute("return {1, 2, nil, 3}"));
|
||||
EXPECT_EQ("[i(1) i(2) nil i(3)]", ser_.res);
|
||||
|
||||
EXPECT_TRUE(Execute("return {1,2,3,'ciao',{1,2}}"));
|
||||
EXPECT_EQ("[i(1) i(2) i(3) str(ciao) [i(1) i(2)]]", ser_.res);
|
||||
}
|
||||
|
||||
} // namespace dfly
|
||||
|
|
|
@ -213,6 +213,11 @@ TEST_F(DflyEngineTest, FlushDb) {
|
|||
ASSERT_FALSE(service_->IsShardSetLocked());
|
||||
}
|
||||
|
||||
TEST_F(DflyEngineTest, Eval) {
|
||||
auto resp = Run({"eval", "return 42", "0"});
|
||||
EXPECT_THAT(resp[0], IntArg(42));
|
||||
}
|
||||
|
||||
// TODO: to test transactions with a single shard since then all transactions become local.
|
||||
// To consider having a parameter in dragonfly engine controlling number of shards
|
||||
// unconditionally from number of cpus. TO TEST BLPOP under multi for single/multi argument case.
|
||||
|
|
|
@ -49,6 +49,51 @@ metrics::CounterFamily cmd_req("requests_total", "Number of served redis request
|
|||
|
||||
constexpr size_t kMaxThreadSize = 1024;
|
||||
|
||||
class EvalSerializer : public ObjectExplorer {
|
||||
public:
|
||||
EvalSerializer(ReplyBuilder* rb) : rb_(rb) {
|
||||
}
|
||||
|
||||
void OnBool(bool b) final {
|
||||
LOG(FATAL) << "TBD";
|
||||
}
|
||||
|
||||
void OnString(std::string_view str) final {
|
||||
LOG(FATAL) << "TBD";
|
||||
}
|
||||
|
||||
void OnDouble(double d) final {
|
||||
LOG(FATAL) << "TBD";
|
||||
}
|
||||
|
||||
void OnInt(int64_t val) final {
|
||||
rb_->SendLong(val);
|
||||
}
|
||||
|
||||
void OnArrayStart(unsigned len) final {
|
||||
LOG(FATAL) << "TBD";
|
||||
}
|
||||
|
||||
void OnArrayEnd() final {
|
||||
LOG(FATAL) << "TBD";
|
||||
}
|
||||
|
||||
void OnNil() final {
|
||||
LOG(FATAL) << "TBD";
|
||||
}
|
||||
|
||||
void OnStatus(std::string_view str) {
|
||||
LOG(FATAL) << "TBD";
|
||||
}
|
||||
|
||||
void OnError(std::string_view str) {
|
||||
LOG(FATAL) << "TBD";
|
||||
}
|
||||
|
||||
private:
|
||||
ReplyBuilder* rb_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Service::Service(ProactorPool* pp) : pp_(*pp), shard_set_(pp), server_family_(this) {
|
||||
|
@ -290,9 +335,35 @@ void Service::Multi(CmdArgList args, ConnectionContext* cntx) {
|
|||
}
|
||||
|
||||
void Service::Eval(CmdArgList args, ConnectionContext* cntx) {
|
||||
Interpreter& script = ServerState::tlocal()->GetInterpreter();
|
||||
script.lua();
|
||||
return cntx->SendOk();
|
||||
string_view body = ArgS(args, 1);
|
||||
string_view num_keys_str = ArgS(args, 2);
|
||||
int32_t num_keys;
|
||||
|
||||
if (!absl::SimpleAtoi(num_keys_str, &num_keys) || num_keys < 0) {
|
||||
return cntx->SendError(kInvalidIntErr);
|
||||
}
|
||||
|
||||
if (unsigned(num_keys) > args.size() - 3) {
|
||||
return cntx->SendError("Number of keys can't be greater than number of args");
|
||||
}
|
||||
|
||||
ServerState* ss = ServerState::tlocal();
|
||||
lock_guard lk(ss->interpreter_mutex);
|
||||
Interpreter& script = ss->GetInterpreter();
|
||||
string error;
|
||||
char f_id[48];
|
||||
bool success = script.Execute(body, f_id, &error);
|
||||
if (success) {
|
||||
EvalSerializer ser{cntx};
|
||||
string error;
|
||||
|
||||
if (!script.Serialize(&ser, &error)) {
|
||||
cntx->SendError(error);
|
||||
}
|
||||
} else {
|
||||
string resp = absl::StrCat("Error running script (call to ", f_id, "): ", error);
|
||||
return cntx->SendError(resp);
|
||||
}
|
||||
}
|
||||
|
||||
void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
|
||||
|
|
|
@ -4,12 +4,13 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <boost/fiber/mutex.hpp>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "core/interpreter.h"
|
||||
#include "server/common_types.h"
|
||||
#include "server/global_state.h"
|
||||
#include "core/interpreter.h"
|
||||
|
||||
namespace dfly {
|
||||
|
||||
|
@ -17,7 +18,7 @@ namespace dfly {
|
|||
// state around engine shards while the former represents coordinator/connection state.
|
||||
// There may be threads that handle engine shards but not IO, there may be threads that handle IO
|
||||
// but not engine shards and there can be threads that handle both. This class is present only
|
||||
// for threads that handle IO and owne coordination fibers.
|
||||
// for threads that handle IO and manage incoming connections.
|
||||
class ServerState { // public struct - to allow initialization.
|
||||
ServerState(const ServerState&) = delete;
|
||||
void operator=(const ServerState&) = delete;
|
||||
|
@ -53,11 +54,20 @@ class ServerState { // public struct - to allow initialization.
|
|||
return live_transactions_;
|
||||
}
|
||||
|
||||
GlobalState::S gstate() const { return gstate_;}
|
||||
void set_gstate(GlobalState::S s) { gstate_ = s; }
|
||||
GlobalState::S gstate() const {
|
||||
return gstate_;
|
||||
}
|
||||
void set_gstate(GlobalState::S s) {
|
||||
gstate_ = s;
|
||||
}
|
||||
|
||||
Interpreter& GetInterpreter();
|
||||
|
||||
// We have interpreter per thread, not per connection.
|
||||
// Since we might preempt into different fibers when operating on interpreter
|
||||
// we must lock it until we finish using it per request.
|
||||
::boost::fibers::mutex interpreter_mutex;
|
||||
|
||||
private:
|
||||
int64_t live_transactions_ = 0;
|
||||
std::optional<Interpreter> interpreter_;
|
||||
|
|
Loading…
Reference in New Issue