Simplify serialization logic in Interpreter

This commit is contained in:
Roman Gershman 2022-02-05 18:20:06 +02:00
parent c567a70244
commit 07df3f2b95
6 changed files with 173 additions and 112 deletions

View File

@ -155,40 +155,6 @@ void SetGlobalArrayInternal(lua_State* lua, const char* name, Interpreter::MutSl
lua_setglobal(lua, name);
}
#if 0
/*
* Save the give pointer on Lua registry, used to save the Lua context and
* function context so we can retrieve them from lua_State.
*/
void SaveOnRegistry(lua_State* lua, const char* name, void* ptr) {
lua_pushstring(lua, name);
if (ptr) {
lua_pushlightuserdata(lua, ptr);
} else {
lua_pushnil(lua);
}
lua_settable(lua, LUA_REGISTRYINDEX);
}
/*
* Get a saved pointer from registry
*/
void* GetFromRegistry(lua_State* lua, const char* name) {
lua_pushstring(lua, name);
lua_gettable(lua, LUA_REGISTRYINDEX);
/* must be light user data */
DCHECK(lua_islightuserdata(lua, -1));
void* ptr = (void*)lua_topointer(lua, -1);
DCHECK(ptr);
/* pops the value */
lua_pop(lua, 1);
return ptr;
}
#endif
/* This function is used in order to push an error on the Lua stack in the
* format used by redis.pcall to return errors, which is a lua table
@ -379,6 +345,7 @@ Interpreter::Interpreter() {
/* Finally set the table as 'redis' global var. */
lua_setglobal(lua_, "redis");
CHECK(lua_checkstack(lua_, 64));
}
Interpreter::~Interpreter() {
@ -461,36 +428,19 @@ void Interpreter::SetGlobalArray(const char* name, MutSliceSpan args) {
SetGlobalArrayInternal(lua_, name, args);
}
/*
bool Interpreter::Execute(string_view body, char f_id[41], string* error) {
lua_getglobal(lua_, "__redis__err__handler");
char fname[43];
fname[0] = 'f';
fname[1] = '_';
FuncSha1(body, f_id);
memcpy(fname + 2, f_id, 41);
int type = lua_getglobal(lua_, fname);
if (type == LUA_TNIL) {
lua_pop(lua_, 1);
if (!AddInternal(fname, body, error))
return false;
type = lua_getglobal(lua_, fname);
CHECK_EQ(type, LUA_TFUNCTION);
} else if (type != LUA_TFUNCTION) {
bool Interpreter::IsResultSafe() const {
int top = lua_gettop(lua_);
if (top >= 128)
return false;
}
int err = lua_pcall(lua_, 0, 1, -2);
if (err) {
*error = lua_tostring(lua_, -1);
}
int t = lua_type(lua_, -1);
if (t != LUA_TTABLE)
return true;
return err == 0;
bool res = IsTableSafe();
lua_settop(lua_, top);
return res;
}
*/
bool Interpreter::AddInternal(const char* f_id, string_view body, string* error) {
string script = absl::StrCat("function ", f_id, "() \n");
@ -511,22 +461,58 @@ bool Interpreter::AddInternal(const char* f_id, string_view body, string* error)
return true;
}
bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) {
// TODO: to get rid of this check or move it to the external function.
// It does not make sense to do this check recursively and it complicates the flow
// were in the middle of the serialization we could theoretically fail.
if (!lua_checkstack(lua_, 4)) {
/* Increase the Lua stack if needed to make sure there is enough room
* to push 4 elements to the stack. On failure, return error.
* Notice that we need, in the worst case, 4 elements because returning a map might
* require push 4 elements to the Lua stack.*/
error->assign("reached lua stack limit");
lua_pop(lua_, 1); /* pop the element from the stack */
return false;
bool Interpreter::IsTableSafe() const {
auto fres = FetchKey(lua_, "err");
if (fres && *fres == LUA_TSTRING) {
return true;
}
fres = FetchKey(lua_, "ok");
if (fres && *fres == LUA_TSTRING) {
return true;
}
vector<pair<unsigned, unsigned>> lens;
unsigned len = lua_rawlen(lua_, -1);
unsigned i = 0;
// implement dfs traversal
while (true) {
while (i < len) {
DVLOG(1) << "Stack " << lua_gettop(lua_) << "/" << i << "/" << len;
int t = lua_rawgeti(lua_, -1, i + 1); // push table element
if (t == LUA_TTABLE) {
if (lens.size() >= 127) // reached depth 128
return false;
CHECK(lua_checkstack(lua_, 1));
lens.emplace_back(i + 1, len); // save the parent state.
// reset to iterate on the next table.
i = 0;
len = lua_rawlen(lua_, -1);
} else {
lua_pop(lua_, 1); // pop table element
++i;
}
}
if (lens.empty()) // exit criteria
break;
// unwind to the state before we went down the stack.
tie(i, len) = lens.back();
lens.pop_back();
lua_pop(lua_, 1);
};
DCHECK_EQ(1, lua_gettop(lua_));
return true;
}
void Interpreter::SerializeResult(ObjectExplorer* serializer) {
int t = lua_type(lua_, -1);
bool res = true;
switch (t) {
case LUA_TSTRING:
@ -566,7 +552,7 @@ bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) {
// to traverse each object. This can be done as a dry-run before doing real serialization.
// Once we are sure we are safe we can simplify the serialization flow and
// remove the error factor.
CHECK(Serialize(serializer, error)); // pops the element
SerializeResult(serializer); // pops the element
}
serializer->OnArrayEnd();
break;
@ -575,11 +561,15 @@ bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) {
serializer->OnNil();
break;
default:
error->assign(absl::StrCat("Unsupported type ", t));
LOG(ERROR) << "Unsupported type " << lua_typename(lua_, t);
serializer->OnNil();
}
lua_pop(lua_, 1);
return res;
}
void Interpreter::ResetStack() {
lua_settop(lua_, 0);
}
// Returns number of results, which is always 1 in this case.

View File

@ -67,15 +67,22 @@ class Interpreter {
RUN_ERR = 2,
};
void SetGlobalArray(const char* name, MutSliceSpan args);
// Runs already added function sha returned by a successful call to AddFunction().
// Returns: true if the call succeeded, otherwise fills error and returns false.
// sha must be 40 char length.
RunResult RunFunction(std::string_view sha, std::string* err);
void SetGlobalArray(const char* name, MutSliceSpan args);
// Checks whether the result is safe to serialize.
// Should fit 2 conditions:
// 1. Be the only value on the stack.
// 2. Should have depth of no more than 128.
bool IsResultSafe() const;
// bool Execute(std::string_view body, char f_id[41], std::string* err);
bool Serialize(ObjectExplorer* serializer, std::string* err);
void SerializeResult(ObjectExplorer* serializer);
void ResetStack();
// fp must point to buffer with at least 41 chars.
// fp[40] will be set to '\0'.
@ -98,6 +105,7 @@ class Interpreter {
// Returns true if function was successfully added,
// otherwise returns false and sets the error.
bool AddInternal(const char* f_id, std::string_view body, std::string* error);
bool IsTableSafe() const;
int RedisGenericCommand(bool raise_error);

View File

@ -76,14 +76,6 @@ class InterpreterTest : public ::testing::Test {
CHECK_EQ(0, lua_pcall(lua(), 0, num_results, 0));
}
bool Serialize(string* err) {
ser_.res.clear();
bool res = intptr_.Serialize(&ser_, err);
if (!ser_.res.empty())
ser_.res.pop_back();
return res;
}
void SetGlobalArray(const char* name, vector<string> vec);
bool Execute(string_view script);
@ -113,7 +105,12 @@ bool InterpreterTest::Execute(string_view script) {
if (run_res != Interpreter::RUN_OK) {
return false;
}
return Serialize(&error_);
ser_.res.clear();
intptr_.SerializeResult(&ser_);
ser_.res.pop_back();
return true;
}
TEST_F(InterpreterTest, Basic) {
@ -160,6 +157,33 @@ TEST_F(InterpreterTest, Basic) {
}
}
TEST_F(InterpreterTest, Stack) {
RunInline(R"(
local x = {}
for i=1,127 do
x = {x}
end
return x
)",
"code1", 1);
ASSERT_EQ(1, lua_gettop(lua()));
ASSERT_TRUE(intptr_.IsResultSafe());
lua_pop(lua(), 1);
RunInline(R"(
local x = {}
for i=1,128 do
x = {x}
end
return x
)",
"code1", 1);
ASSERT_EQ(1, lua_gettop(lua()));
ASSERT_FALSE(intptr_.IsResultSafe());
}
TEST_F(InterpreterTest, Add) {
string res1, res2;

View File

@ -1,5 +1,18 @@
diff --git a/luaconf.h b/luaconf.h
index d42d14b7..75647e72 100644
--- a/luaconf.h
+++ b/luaconf.h
@@ -731,7 +731,7 @@
** (It must fit into max(size_t)/32.)
*/
#if LUAI_IS32INT
-#define LUAI_MAXSTACK 1000000
+#define LUAI_MAXSTACK 4096
#else
#define LUAI_MAXSTACK 15000
#endif
diff --git a/makefile b/makefile
index d46e650c..85e6b637 100644
index d46e650c..52d8d57b 100644
--- a/makefile
+++ b/makefile
@@ -66,9 +66,9 @@ LOCAL = $(TESTS) $(CWARNS)
@ -7,7 +20,7 @@ index d46e650c..85e6b637 100644
# enable Linux goodies
-MYCFLAGS= $(LOCAL) -std=c99 -DLUA_USE_LINUX -DLUA_USE_READLINE
+MYCFLAGS= $(LOCAL) -std=c99 -DLUA_USE_LINUX
+MYCFLAGS= $(LOCAL) -std=c99 -g -O2 -DLUA_USE_LINUX
MYLDFLAGS= $(LOCAL) -Wl,-E
-MYLIBS= -ldl -lreadline
+MYLIBS= -ldl

View File

@ -34,7 +34,7 @@ struct ConnectionState {
// Whether this connection belongs to replica, i.e. a dragonfly slave is connected to this
// host (master) via this connection to sync from it.
REPL_CONNECTION = 2,
REPL_CONNECTION = 4,
};
uint32_t mask = 0; // A bitmask of Mask values.
@ -46,6 +46,12 @@ struct ConnectionState {
bool IsRunViaDispatch() const {
return mask & ASYNC_DISPATCH;
}
// Lua-script related data.
struct Script {
bool is_write = true;
};
std::optional<Script> script_info;
};
class ConnectionContext {

View File

@ -99,6 +99,14 @@ class EvalSerializer : public ObjectExplorer {
RedisReplyBuilder* rb_;
};
bool IsSHA(string_view str) {
for (auto c : str) {
if (!absl::ascii_isxdigit(c))
return false;
}
return true;
}
} // namespace
Service::Service(ProactorPool* pp) : pp_(*pp), shard_set_(pp), server_family_(this) {
@ -189,7 +197,14 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
return;
}
bool is_write_cmd = cid->opt_mask() & CO::WRITE;
bool under_script = cntx->conn_state.script_info.has_value();
if (under_script && (cid->opt_mask() & CO::NOSCRIPT)) {
return (*cntx)->SendError("This Redis command is not allowed from script");
}
bool is_write_cmd =
(cid->opt_mask() & CO::WRITE) || (under_script && cntx->conn_state.script_info->is_write);
bool under_multi = cntx->conn_state.exec_state != ConnectionState::EXEC_INACTIVE && !is_trans_cmd;
if (!etl.is_master && is_write_cmd) {
@ -340,7 +355,8 @@ void Service::Multi(CmdArgList args, ConnectionContext* cntx) {
}
void Service::CallFromScript(CmdArgList args, ObjectExplorer* reply, ConnectionContext* cntx) {
reply->OnInt(42);
// reply->OnInt(42);
DispatchCommand(std::move(args), cntx);
}
void Service::Eval(CmdArgList args, ConnectionContext* cntx) {
@ -402,7 +418,6 @@ void Service::EvalSha(CmdArgList args, ConnectionContext* cntx) {
if (!exists) {
const char* body = (sha.size() == 40) ? server_family_.script_mgr()->Find(sha) : nullptr;
if (!body) {
return (*cntx)->SendError(kScriptNotFound);
}
@ -424,7 +439,8 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter,
ConnectionContext* cntx) {
DCHECK(!eval_args.sha.empty());
if (eval_args.sha.size() != 40) {
// Sanitizing the input to avoid code injection.
if (eval_args.sha.size() != 40 || !IsSHA(eval_args.sha)) {
return (*cntx)->SendError(kScriptNotFound);
}
@ -443,34 +459,38 @@ void Service::EvalInternal(const EvalArgs& eval_args, Interpreter* interpreter,
string error;
DCHECK(!cntx->conn_state.script_info); // we should not call eval from the script.
// TODO: to determine whether the script is RO by scanning all "redis.p?call" calls
// and checking whether all invocations consist of RO commands.
// we can do it once during script insertion into script mgr.
cntx->conn_state.script_info.emplace();
auto lk = interpreter->Lock();
interpreter->SetGlobalArray("KEYS", eval_args.keys);
interpreter->SetGlobalArray("ARGV", eval_args.args);
interpreter->SetRedisFunc(
[cntx, this](CmdArgList args, ObjectExplorer* reply) { CallFromScript(args, reply, cntx); });
bool success = false;
if (eval_args.sha.empty()) {
} else {
Interpreter::RunResult result = interpreter->RunFunction(eval_args.sha, &error);
if (result == Interpreter::RUN_ERR) {
return (*cntx)->SendError(error);
}
CHECK(result == Interpreter::RUN_OK);
success = true;
}
Interpreter::RunResult result = interpreter->RunFunction(eval_args.sha, &error);
if (success) {
EvalSerializer ser{static_cast<RedisReplyBuilder*>(cntx->reply_builder())};
string error;
cntx->conn_state.script_info.reset(); // reset script_info
if (!interpreter->Serialize(&ser, &error)) {
(*cntx)->SendError(error);
}
} else {
if (result == Interpreter::RUN_ERR) {
string resp = absl::StrCat("Error running script (call to ", eval_args.sha, "): ", error);
return (*cntx)->SendError(resp);
}
CHECK(result == Interpreter::RUN_OK);
EvalSerializer ser{static_cast<RedisReplyBuilder*>(cntx->reply_builder())};
if (!interpreter->IsResultSafe()) {
(*cntx)->SendError("reached lua stack limit");
} else {
interpreter->SerializeResult(&ser);
}
interpreter->ResetStack();
}
void Service::Exec(CmdArgList args, ConnectionContext* cntx) {