From 07df3f2b9534f86dad7392bda5fcc9485861038a Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Sat, 5 Feb 2022 18:20:06 +0200 Subject: [PATCH] Simplify serialization logic in Interpreter --- core/interpreter.cc | 142 ++++++++++++++++++--------------------- core/interpreter.h | 14 +++- core/interpreter_test.cc | 42 +++++++++--- patches/lua-v5.4.4.patch | 17 ++++- server/conn_context.h | 8 ++- server/main_service.cc | 62 +++++++++++------ 6 files changed, 173 insertions(+), 112 deletions(-) diff --git a/core/interpreter.cc b/core/interpreter.cc index b2bb456..5b54c41 100644 --- a/core/interpreter.cc +++ b/core/interpreter.cc @@ -155,40 +155,6 @@ void SetGlobalArrayInternal(lua_State* lua, const char* name, Interpreter::MutSl lua_setglobal(lua, name); } -#if 0 -/* - * Save the give pointer on Lua registry, used to save the Lua context and - * function context so we can retrieve them from lua_State. - */ -void SaveOnRegistry(lua_State* lua, const char* name, void* ptr) { - lua_pushstring(lua, name); - if (ptr) { - lua_pushlightuserdata(lua, ptr); - } else { - lua_pushnil(lua); - } - lua_settable(lua, LUA_REGISTRYINDEX); -} - -/* - * Get a saved pointer from registry - */ -void* GetFromRegistry(lua_State* lua, const char* name) { - lua_pushstring(lua, name); - lua_gettable(lua, LUA_REGISTRYINDEX); - - /* must be light user data */ - DCHECK(lua_islightuserdata(lua, -1)); - - void* ptr = (void*)lua_topointer(lua, -1); - DCHECK(ptr); - - /* pops the value */ - lua_pop(lua, 1); - - return ptr; -} -#endif /* This function is used in order to push an error on the Lua stack in the * format used by redis.pcall to return errors, which is a lua table @@ -379,6 +345,7 @@ Interpreter::Interpreter() { /* Finally set the table as 'redis' global var. */ lua_setglobal(lua_, "redis"); + CHECK(lua_checkstack(lua_, 64)); } Interpreter::~Interpreter() { @@ -461,36 +428,19 @@ void Interpreter::SetGlobalArray(const char* name, MutSliceSpan args) { SetGlobalArrayInternal(lua_, name, args); } -/* -bool Interpreter::Execute(string_view body, char f_id[41], string* error) { - lua_getglobal(lua_, "__redis__err__handler"); - char fname[43]; - - fname[0] = 'f'; - fname[1] = '_'; - FuncSha1(body, f_id); - memcpy(fname + 2, f_id, 41); - - int type = lua_getglobal(lua_, fname); - if (type == LUA_TNIL) { - lua_pop(lua_, 1); - if (!AddInternal(fname, body, error)) - return false; - - type = lua_getglobal(lua_, fname); - CHECK_EQ(type, LUA_TFUNCTION); - } else if (type != LUA_TFUNCTION) { +bool Interpreter::IsResultSafe() const { + int top = lua_gettop(lua_); + if (top >= 128) return false; - } - int err = lua_pcall(lua_, 0, 1, -2); - if (err) { - *error = lua_tostring(lua_, -1); - } + int t = lua_type(lua_, -1); + if (t != LUA_TTABLE) + return true; - return err == 0; + bool res = IsTableSafe(); + lua_settop(lua_, top); + return res; } -*/ bool Interpreter::AddInternal(const char* f_id, string_view body, string* error) { string script = absl::StrCat("function ", f_id, "() \n"); @@ -511,22 +461,58 @@ bool Interpreter::AddInternal(const char* f_id, string_view body, string* error) return true; } -bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) { - // TODO: to get rid of this check or move it to the external function. - // It does not make sense to do this check recursively and it complicates the flow - // were in the middle of the serialization we could theoretically fail. - if (!lua_checkstack(lua_, 4)) { - /* Increase the Lua stack if needed to make sure there is enough room - * to push 4 elements to the stack. On failure, return error. - * Notice that we need, in the worst case, 4 elements because returning a map might - * require push 4 elements to the Lua stack.*/ - error->assign("reached lua stack limit"); - lua_pop(lua_, 1); /* pop the element from the stack */ - return false; +bool Interpreter::IsTableSafe() const { + auto fres = FetchKey(lua_, "err"); + if (fres && *fres == LUA_TSTRING) { + return true; } + fres = FetchKey(lua_, "ok"); + if (fres && *fres == LUA_TSTRING) { + return true; + } + + vector> lens; + unsigned len = lua_rawlen(lua_, -1); + unsigned i = 0; + + // implement dfs traversal + while (true) { + while (i < len) { + DVLOG(1) << "Stack " << lua_gettop(lua_) << "/" << i << "/" << len; + int t = lua_rawgeti(lua_, -1, i + 1); // push table element + if (t == LUA_TTABLE) { + if (lens.size() >= 127) // reached depth 128 + return false; + + CHECK(lua_checkstack(lua_, 1)); + lens.emplace_back(i + 1, len); // save the parent state. + + // reset to iterate on the next table. + i = 0; + len = lua_rawlen(lua_, -1); + } else { + lua_pop(lua_, 1); // pop table element + ++i; + } + } + + if (lens.empty()) // exit criteria + break; + + // unwind to the state before we went down the stack. + tie(i, len) = lens.back(); + lens.pop_back(); + + lua_pop(lua_, 1); + }; + DCHECK_EQ(1, lua_gettop(lua_)); + + return true; +} + +void Interpreter::SerializeResult(ObjectExplorer* serializer) { int t = lua_type(lua_, -1); - bool res = true; switch (t) { case LUA_TSTRING: @@ -566,7 +552,7 @@ bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) { // to traverse each object. This can be done as a dry-run before doing real serialization. // Once we are sure we are safe we can simplify the serialization flow and // remove the error factor. - CHECK(Serialize(serializer, error)); // pops the element + SerializeResult(serializer); // pops the element } serializer->OnArrayEnd(); break; @@ -575,11 +561,15 @@ bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) { serializer->OnNil(); break; default: - error->assign(absl::StrCat("Unsupported type ", t)); + LOG(ERROR) << "Unsupported type " << lua_typename(lua_, t); + serializer->OnNil(); } lua_pop(lua_, 1); - return res; +} + +void Interpreter::ResetStack() { + lua_settop(lua_, 0); } // Returns number of results, which is always 1 in this case. diff --git a/core/interpreter.h b/core/interpreter.h index b828ea2..e80e0de 100644 --- a/core/interpreter.h +++ b/core/interpreter.h @@ -67,15 +67,22 @@ class Interpreter { RUN_ERR = 2, }; + void SetGlobalArray(const char* name, MutSliceSpan args); + // Runs already added function sha returned by a successful call to AddFunction(). // Returns: true if the call succeeded, otherwise fills error and returns false. // sha must be 40 char length. RunResult RunFunction(std::string_view sha, std::string* err); - void SetGlobalArray(const char* name, MutSliceSpan args); + // Checks whether the result is safe to serialize. + // Should fit 2 conditions: + // 1. Be the only value on the stack. + // 2. Should have depth of no more than 128. + bool IsResultSafe() const; - // bool Execute(std::string_view body, char f_id[41], std::string* err); - bool Serialize(ObjectExplorer* serializer, std::string* err); + void SerializeResult(ObjectExplorer* serializer); + + void ResetStack(); // fp must point to buffer with at least 41 chars. // fp[40] will be set to '\0'. @@ -98,6 +105,7 @@ class Interpreter { // Returns true if function was successfully added, // otherwise returns false and sets the error. bool AddInternal(const char* f_id, std::string_view body, std::string* error); + bool IsTableSafe() const; int RedisGenericCommand(bool raise_error); diff --git a/core/interpreter_test.cc b/core/interpreter_test.cc index 30bc49e..25bf758 100644 --- a/core/interpreter_test.cc +++ b/core/interpreter_test.cc @@ -76,14 +76,6 @@ class InterpreterTest : public ::testing::Test { CHECK_EQ(0, lua_pcall(lua(), 0, num_results, 0)); } - bool Serialize(string* err) { - ser_.res.clear(); - bool res = intptr_.Serialize(&ser_, err); - if (!ser_.res.empty()) - ser_.res.pop_back(); - return res; - } - void SetGlobalArray(const char* name, vector vec); bool Execute(string_view script); @@ -113,7 +105,12 @@ bool InterpreterTest::Execute(string_view script) { if (run_res != Interpreter::RUN_OK) { return false; } - return Serialize(&error_); + + ser_.res.clear(); + intptr_.SerializeResult(&ser_); + ser_.res.pop_back(); + + return true; } TEST_F(InterpreterTest, Basic) { @@ -160,6 +157,33 @@ TEST_F(InterpreterTest, Basic) { } } +TEST_F(InterpreterTest, Stack) { + RunInline(R"( +local x = {} +for i=1,127 do + x = {x} +end +return x +)", + "code1", 1); + + ASSERT_EQ(1, lua_gettop(lua())); + ASSERT_TRUE(intptr_.IsResultSafe()); + lua_pop(lua(), 1); + + RunInline(R"( +local x = {} +for i=1,128 do + x = {x} +end +return x +)", + "code1", 1); + + ASSERT_EQ(1, lua_gettop(lua())); + ASSERT_FALSE(intptr_.IsResultSafe()); +} + TEST_F(InterpreterTest, Add) { string res1, res2; diff --git a/patches/lua-v5.4.4.patch b/patches/lua-v5.4.4.patch index 117c821..83b2efa 100644 --- a/patches/lua-v5.4.4.patch +++ b/patches/lua-v5.4.4.patch @@ -1,5 +1,18 @@ +diff --git a/luaconf.h b/luaconf.h +index d42d14b7..75647e72 100644 +--- a/luaconf.h ++++ b/luaconf.h +@@ -731,7 +731,7 @@ + ** (It must fit into max(size_t)/32.) + */ + #if LUAI_IS32INT +-#define LUAI_MAXSTACK 1000000 ++#define LUAI_MAXSTACK 4096 + #else + #define LUAI_MAXSTACK 15000 + #endif diff --git a/makefile b/makefile -index d46e650c..85e6b637 100644 +index d46e650c..52d8d57b 100644 --- a/makefile +++ b/makefile @@ -66,9 +66,9 @@ LOCAL = $(TESTS) $(CWARNS) @@ -7,7 +20,7 @@ index d46e650c..85e6b637 100644 # enable Linux goodies -MYCFLAGS= $(LOCAL) -std=c99 -DLUA_USE_LINUX -DLUA_USE_READLINE -+MYCFLAGS= $(LOCAL) -std=c99 -DLUA_USE_LINUX ++MYCFLAGS= $(LOCAL) -std=c99 -g -O2 -DLUA_USE_LINUX MYLDFLAGS= $(LOCAL) -Wl,-E -MYLIBS= -ldl -lreadline +MYLIBS= -ldl diff --git a/server/conn_context.h b/server/conn_context.h index 5a410a7..d832bbc 100644 --- a/server/conn_context.h +++ b/server/conn_context.h @@ -34,7 +34,7 @@ struct ConnectionState { // Whether this connection belongs to replica, i.e. a dragonfly slave is connected to this // host (master) via this connection to sync from it. - REPL_CONNECTION = 2, + REPL_CONNECTION = 4, }; uint32_t mask = 0; // A bitmask of Mask values. @@ -46,6 +46,12 @@ struct ConnectionState { bool IsRunViaDispatch() const { return mask & ASYNC_DISPATCH; } + + // Lua-script related data. + struct Script { + bool is_write = true; + }; + std::optional