Allow calling a redis function from interpreter.

Introduce a translator that converts redis response to lua result coming from redis.call
Add tests.
This commit is contained in:
Roman Gershman 2022-02-03 00:31:26 +02:00
parent 4cff2d8b7d
commit 067e1c3b62
11 changed files with 302 additions and 58 deletions

3
TODO.md Normal file
View File

@ -0,0 +1,3 @@
1. To move lua_project to dragonfly from helio
2. To limit lua stack to something reasonable like 4096.
3. To inject our own allocator to lua to track its memory.

View File

@ -1,6 +1,6 @@
add_library(dfly_core compact_object.cc dragonfly_core.cc interpreter.cc
tx_queue.cc)
cxx_link(dfly_core base absl::flat_hash_map redis_lib TRDP::lua crypto)
cxx_link(dfly_core base absl::flat_hash_map absl::str_format redis_lib TRDP::lua crypto)
cxx_test(dfly_core_test dfly_core LABELS DFLY)
cxx_test(compact_object_test dfly_core LABELS DFLY)

View File

@ -978,6 +978,7 @@ void Segment<Key, Value, Policy>::Split(HFunc&& hfn, Segment* dest_right) {
auto it = dest_right->InsertUniq(std::forward<Key_t>(key),
std::forward<Value_t>(Value(i, slot)), hash);
(void)it;
if constexpr (USE_VERSION) {
// Maintaining consistent versioning.
dest_right->bucket_[it.index].SetVersion(it.slot, bucket_[i].GetVersion(slot));
@ -1010,7 +1011,7 @@ void Segment<Key, Value, Policy>::Split(HFunc&& hfn, Segment* dest_right) {
invalid_mask |= (1u << slot);
auto it = dest_right->InsertUniq(std::forward<Key_t>(Key(bid, slot)),
std::forward<Value_t>(Value(bid, slot)), hash);
(void)it;
if constexpr (USE_VERSION) {
// Update the version in the destination bucket.
dest_right->bucket_[it.index].SetVersion(it.slot, stash.GetVersion(slot));

View File

@ -16,6 +16,8 @@ extern "C" {
#include <lualib.h>
}
#include <absl/strings/str_format.h>
#include "base/logging.h"
namespace dfly {
@ -23,6 +25,104 @@ using namespace std;
namespace {
class RedisTranslator : public ObjectExplorer {
public:
RedisTranslator(lua_State* lua) : lua_(lua) {
}
void OnBool(bool b) final;
void OnString(std::string_view str) final;
void OnDouble(double d) final;
void OnInt(int64_t val) final;
void OnArrayStart(unsigned len) final;
void OnArrayEnd() final;
void OnNil() final;
void OnStatus(std::string_view str) final;
void OnError(std::string_view str) final;
private:
void ArrayPre() {
/*if (!array_index_.empty()) {
lua_pushnumber(lua_, array_index_.back());
array_index_.back()++;
}*/
}
void ArrayPost() {
if (!array_index_.empty()) {
lua_rawseti(lua_, -2, array_index_.back()++); /* set table at key `i' */
// lua_settable(lua_, -3);
}
}
vector<unsigned> array_index_;
lua_State* lua_;
};
void RedisTranslator::OnBool(bool b) {
CHECK(!b) << "Only false (nil) supported";
ArrayPre();
lua_pushboolean(lua_, 0);
ArrayPost();
}
void RedisTranslator::OnString(std::string_view str) {
ArrayPre();
lua_pushlstring(lua_, str.data(), str.size());
ArrayPost();
}
// Doubles are not supported by Redis, however we can support them.
// Here is the use-case:
// local foo = redis.call('zscore', 'myzset', 'one')
// assert(type(foo) == "number")
void RedisTranslator::OnDouble(double d) {
ArrayPre();
lua_pushnumber(lua_, d);
ArrayPost();
}
void RedisTranslator::OnInt(int64_t val) {
ArrayPre();
lua_pushinteger(lua_, val);
ArrayPost();
}
void RedisTranslator::OnNil() {
ArrayPre();
lua_pushboolean(lua_, 0);
ArrayPost();
}
void RedisTranslator::OnStatus(std::string_view str) {
CHECK(array_index_.empty()) << "unexpected status";
lua_newtable(lua_);
lua_pushstring(lua_, "ok");
lua_pushlstring(lua_, str.data(), str.size());
lua_settable(lua_, -3);
}
void RedisTranslator::OnError(std::string_view str) {
CHECK(array_index_.empty()) << "unexpected error";
lua_newtable(lua_);
lua_pushstring(lua_, "err");
lua_pushlstring(lua_, str.data(), str.size());
lua_settable(lua_, -3);
}
void RedisTranslator::OnArrayStart(unsigned len) {
ArrayPre();
lua_newtable(lua_);
array_index_.push_back(1);
}
void RedisTranslator::OnArrayEnd() {
CHECK(!array_index_.empty());
DCHECK(lua_istable(lua_, -1));
array_index_.pop_back();
ArrayPost();
}
void RunSafe(lua_State* lua, string_view buf, const char* name) {
CHECK_EQ(0, luaL_loadbuffer(lua, buf.data(), buf.size(), name));
int err = lua_pcall(lua, 0, 0, 0);
@ -388,35 +488,32 @@ bool Interpreter::Serialize(ObjectExplorer* serializer, std::string* error) {
}
break;
case LUA_TTABLE: {
unsigned len = lua_rawlen(lua_, -1);
if (len > 0) { // array
serializer->OnArrayStart(len);
for (unsigned i = 0; i < len; ++i) {
t = lua_rawgeti(lua_, -1, i + 1); // push table element
// TODO: we should make sure that we have enough stack space
// to traverse each object. This can be done as a dry-run before doing real serialization.
// Once we are sure we are safe we can simplify the serialization flow and
// remove the error factor.
CHECK(Serialize(serializer, error)); // pops the element
}
serializer->OnArrayEnd();
} else {
auto fres = FetchKey(lua_, "err");
if (fres && *fres == LUA_TSTRING) {
serializer->OnError(TopSv(lua_));
lua_pop(lua_, 1);
break;
}
fres = FetchKey(lua_, "ok");
if (fres && *fres == LUA_TSTRING) {
serializer->OnStatus(TopSv(lua_));
lua_pop(lua_, 1);
break;
}
serializer->OnError("TBD");
auto fres = FetchKey(lua_, "err");
if (fres && *fres == LUA_TSTRING) {
serializer->OnError(TopSv(lua_));
lua_pop(lua_, 1);
break;
}
fres = FetchKey(lua_, "ok");
if (fres && *fres == LUA_TSTRING) {
serializer->OnStatus(TopSv(lua_));
lua_pop(lua_, 1);
break;
}
unsigned len = lua_rawlen(lua_, -1);
serializer->OnArrayStart(len);
for (unsigned i = 0; i < len; ++i) {
t = lua_rawgeti(lua_, -1, i + 1); // push table element
// TODO: we should make sure that we have enough stack space
// to traverse each object. This can be done as a dry-run before doing real serialization.
// Once we are sure we are safe we can simplify the serialization flow and
// remove the error factor.
CHECK(Serialize(serializer, error)); // pops the element
}
serializer->OnArrayEnd();
break;
}
case LUA_TNIL:
@ -446,6 +543,11 @@ int Interpreter::RedisGenericCommand(bool raise_error) {
return 1;
}
if (!redis_func_) {
PushError(lua_, "internal error - redis function not defined");
return raise_error ? RaiseError(lua_) : 1;
}
cmd_depth_++;
int argc = lua_gettop(lua_);
@ -457,14 +559,61 @@ int Interpreter::RedisGenericCommand(bool raise_error) {
return raise_error ? RaiseError(lua_) : 1;
}
// TODO: to prepare arguments.
size_t blob_len = 0;
char tmpbuf[64];
for (int j = 0; j < argc; j++) {
unsigned idx = j + 1;
if (lua_isinteger(lua_, idx)) {
absl::AlphaNum an(lua_tointeger(lua_, idx));
blob_len += an.size();
} else if (lua_isnumber(lua_, idx)) {
// fmt_len does not include '\0'.
int fmt_len = absl::SNPrintF(tmpbuf, sizeof(tmpbuf), "%.17g", lua_tonumber(lua_, idx));
CHECK_GT(fmt_len, 0);
blob_len += fmt_len;
} else if (lua_isstring(lua_, idx)) {
blob_len += lua_rawlen(lua_, idx); // lua_rawlen does not include '\0'.
} else {
PushError(lua_, "Lua redis() command arguments must be strings or integers");
cmd_depth_--;
return raise_error ? RaiseError(lua_) : 1;
}
}
// backing storage.
unique_ptr<char[]> blob(new char[blob_len + 8]); // 8 safety.
vector<absl::Span<char>> cmdargs;
char* cur = blob.get();
char* end = cur + blob_len;
for (int j = 0; j < argc; j++) {
unsigned idx = j + 1;
size_t len;
if (lua_isinteger(lua_, idx)) {
char* next = absl::numbers_internal::FastIntToBuffer(lua_tointeger(lua_, idx), cur);
len = next - cur;
} else if (lua_isnumber(lua_, idx)) {
int fmt_len = absl::SNPrintF(cur, end - cur, "%.17g", lua_tonumber(lua_, idx));
CHECK_GT(fmt_len, 0);
len = fmt_len;
} else if (lua_isstring(lua_, idx)) {
len = lua_rawlen(lua_, idx);
memcpy(cur, lua_tostring(lua_, idx), len); // copy \0 as well.
}
cmdargs.emplace_back(cur, len);
cur += len;
}
/* Pop all arguments from the stack, we do not need them anymore
* and this way we guaranty we will have room on the stack for the result. */
lua_pop(lua_, argc);
RedisTranslator translator(lua_);
redis_func_(MutSliceSpan{cmdargs}, &translator);
DCHECK_EQ(1, lua_gettop(lua_));
cmd_depth_--;
lua_pushinteger(lua_, 42);
return 1;
}

View File

@ -4,6 +4,9 @@
#pragma once
#include <absl/types/span.h>
#include <functional>
#include <string_view>
typedef struct lua_State lua_State;
@ -11,8 +14,9 @@ typedef struct lua_State lua_State;
namespace dfly {
class ObjectExplorer {
public:
virtual ~ObjectExplorer() {}
public:
virtual ~ObjectExplorer() {
}
virtual void OnBool(bool b) = 0;
virtual void OnString(std::string_view str) = 0;
@ -54,16 +58,25 @@ class Interpreter {
// fp[42] will be set to '\0'.
static void Fingerprint(std::string_view body, char* fp);
using MutableSlice = absl::Span<char>;
using MutSliceSpan = absl::Span<MutableSlice>;
using RedisFunc = std::function<void(MutSliceSpan, ObjectExplorer*)>;
template<typename U> void SetRedisFunc(U&& u) {
redis_func_ = std::forward<U>(u);
}
private:
bool AddInternal(const char* f_id, std::string_view body, std::string* result);
int RedisGenericCommand(bool raise_error);
static int RedisCallCommand(lua_State *lua);
static int RedisPCallCommand(lua_State *lua);
static int RedisCallCommand(lua_State* lua);
static int RedisPCallCommand(lua_State* lua);
lua_State* lua_;
unsigned cmd_depth_ = 0;
RedisFunc redis_func_;
};
} // namespace dfly

View File

@ -24,22 +24,18 @@ class TestSerializer : public ObjectExplorer {
void OnBool(bool b) final {
absl::StrAppend(&res, "bool(", b, ") ");
OnItem();
}
void OnString(std::string_view str) final {
absl::StrAppend(&res, "str(", str, ") ");
OnItem();
}
void OnDouble(double d) final {
absl::StrAppend(&res, "d(", d, ") ");
OnItem();
}
void OnInt(int64_t val) final {
absl::StrAppend(&res, "i(", val, ") ");
OnItem();
}
void OnArrayStart(unsigned len) final {
@ -64,10 +60,6 @@ class TestSerializer : public ObjectExplorer {
void OnError(std::string_view str) {
absl::StrAppend(&res, "err(", str, ") ");
}
private:
void OnItem() {
}
};
class InterpreterTest : public ::testing::Test {
@ -79,9 +71,9 @@ class InterpreterTest : public ::testing::Test {
return intptr_.lua();
}
void RunInline(string_view buf, const char* name) {
void RunInline(string_view buf, const char* name, unsigned num_results = 0) {
CHECK_EQ(0, luaL_loadbuffer(lua(), buf.data(), buf.size(), name));
CHECK_EQ(0, lua_pcall(lua(), 0, 0, 0));
CHECK_EQ(0, lua_pcall(lua(), 0, num_results, 0));
}
bool Serialize(string* err) {
@ -123,6 +115,30 @@ TEST_F(InterpreterTest, Basic) {
EXPECT_EQ(43, val1);
EXPECT_EQ(42, val2);
EXPECT_EQ(0, lua_gettop(lua()));
lua_pushstring(lua(), "foo");
EXPECT_EQ(3, lua_rawlen(lua(), 1));
lua_pop(lua(), 1);
RunInline("return {nil, 'b'}", "code2", 1);
ASSERT_EQ(1, lua_gettop(lua()));
LOG(INFO) << lua_typename(lua(), lua_type(lua(), -1));
ASSERT_TRUE(lua_istable(lua(), -1));
ASSERT_EQ(2, lua_rawlen(lua(), -1));
lua_len(lua(), -1);
ASSERT_EQ(2, lua_tointeger(lua(), -1));
lua_pop(lua(), 1);
lua_pushnil(lua());
while (lua_next(lua(), -2)) {
/* uses 'key' (at index -2) and 'value' (at index -1) */
int kt = lua_type(lua(), -2);
int vt = lua_type(lua(), -1);
LOG(INFO) << "k/v : " << lua_typename(lua(), kt) << "/" << lua_tonumber(lua(), -2) << " "
<< lua_typename(lua(), vt);
lua_pop(lua(), 1);
}
}
TEST_F(InterpreterTest, Add) {
@ -162,8 +178,62 @@ TEST_F(InterpreterTest, Execute) {
EXPECT_TRUE(Execute("return {1, 2, nil, 3}"));
EXPECT_EQ("[i(1) i(2) nil i(3)]", ser_.res);
EXPECT_TRUE(Execute("return {1,2,3,'ciao',{1,2}}"));
EXPECT_TRUE(Execute("return {1,2,3,'ciao', {1,2}}"));
EXPECT_EQ("[i(1) i(2) i(3) str(ciao) [i(1) i(2)]]", ser_.res);
}
TEST_F(InterpreterTest, Call) {
auto cb = [](Interpreter::MutSliceSpan span, ObjectExplorer* reply) {
CHECK_GE(span.size(), 1u);
string_view cmd{span[0].data(), span[0].size()};
if (cmd == "string") {
reply->OnString("foo");
} else if (cmd == "double") {
reply->OnDouble(3.1415);
} else if (cmd == "int") {
reply->OnInt(42);
} else if (cmd == "err") {
reply->OnError("myerr");
} else if (cmd == "status") {
reply->OnStatus("mystatus");
} else {
LOG(FATAL) << "Invalid param";
}
};
intptr_.SetRedisFunc(cb);
EXPECT_TRUE(Execute("local var = redis.call('string'); return {type(var), var}"));
EXPECT_EQ("[str(string) str(foo)]", ser_.res);
EXPECT_TRUE(Execute("local var = redis.call('double'); return {type(var), var}"));
EXPECT_EQ("[str(number) d(3.1415)]", ser_.res);
EXPECT_TRUE(Execute("local var = redis.call('int'); return {type(var), var}"));
EXPECT_EQ("[str(number) i(42)]", ser_.res);
EXPECT_TRUE(Execute("local var = redis.call('err'); return {type(var), var}"));
EXPECT_EQ("[str(table) err(myerr)]", ser_.res);
EXPECT_TRUE(Execute("local var = redis.call('status'); return {type(var), var}"));
EXPECT_EQ("[str(table) status(mystatus)]", ser_.res);
}
TEST_F(InterpreterTest, CallArray) {
auto cb = [](Interpreter::MutSliceSpan span, ObjectExplorer* reply) {
reply->OnArrayStart(2);
reply->OnArrayStart(1);
reply->OnArrayStart(2);
reply->OnNil();
reply->OnString("s2");
reply->OnArrayEnd();
reply->OnArrayEnd();
reply->OnInt(42);
reply->OnArrayEnd();
};
intptr_.SetRedisFunc(cb);
EXPECT_TRUE(Execute("local var = redis.call(''); return {type(var), var}"));
EXPECT_EQ("[str(table) [[[bool(0) str(s2)]] i(42)]]", ser_.res);
}
} // namespace dfly

View File

@ -25,9 +25,9 @@ using TxId = uint64_t;
using TxClock = uint64_t;
using ArgSlice = absl::Span<const std::string_view>;
using MutableStrSpan = absl::Span<char>;
using CmdArgList = absl::Span<MutableStrSpan>;
using CmdArgVec = std::vector<MutableStrSpan>;
using MutableSlice = absl::Span<char>;
using CmdArgList = absl::Span<MutableSlice>;
using CmdArgVec = std::vector<MutableSlice>;
constexpr DbIndex kInvalidDbId = DbIndex(-1);
constexpr ShardId kInvalidSid = ShardId(-1);
@ -73,11 +73,11 @@ inline std::string_view ArgS(CmdArgList args, size_t i) {
return std::string_view(arg.data(), arg.size());
}
inline MutableStrSpan ToMSS(absl::Span<uint8_t> span) {
return MutableStrSpan{reinterpret_cast<char*>(span.data()), span.size()};
inline MutableSlice ToMSS(absl::Span<uint8_t> span) {
return MutableSlice{reinterpret_cast<char*>(span.data()), span.size()};
}
inline void ToUpper(const MutableStrSpan* val) {
inline void ToUpper(const MutableSlice* val) {
for (auto& c : *val) {
c = absl::ascii_toupper(c);
}

View File

@ -449,7 +449,7 @@ auto Connection::FromArgs(RespVec args) -> Request* {
auto buf = args[i].GetBuf();
size_t s = buf.size();
memcpy(next, buf.data(), s);
req->args[i] = MutableStrSpan(next, s);
req->args[i] = MutableSlice(next, s);
next += s;
}

View File

@ -63,7 +63,7 @@ class Connection : public util::Connection {
std::unique_ptr<ConnectionContext> cc_;
struct Request {
absl::FixedArray<MutableStrSpan> args;
absl::FixedArray<MutableSlice> args;
absl::FixedArray<char> storage;
Request(size_t nargs, size_t capacity) : args(nargs), storage(capacity) {

View File

@ -98,6 +98,11 @@ class EvalSerializer : public ObjectExplorer {
ReplyBuilder* rb_;
};
void CallFromScript(CmdArgList args, ObjectExplorer* reply, ConnectionContext* cntx,
Service* service) {
reply->OnInt(42);
}
} // namespace
Service::Service(ProactorPool* pp) : pp_(*pp), shard_set_(pp), server_family_(this) {
@ -256,7 +261,7 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view value,
ConnectionContext* cntx) {
absl::InlinedVector<MutableStrSpan, 8> args;
absl::InlinedVector<MutableSlice, 8> args;
char cmd_name[16];
char set_opt[4] = {0};
@ -361,6 +366,10 @@ void Service::Eval(CmdArgList args, ConnectionContext* cntx) {
EvalSerializer ser{cntx};
string error;
script.SetRedisFunc([cntx](CmdArgList args, ObjectExplorer* reply) {
CallFromScript(args, reply, cntx, nullptr);
});
if (!script.Serialize(&ser, &error)) {
cntx->SendError(error);
}
@ -390,7 +399,7 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {
str_list.resize(scmd.cmd.size());
for (size_t i = 0; i < scmd.cmd.size(); ++i) {
string& s = scmd.cmd[i];
str_list[i] = MutableStrSpan{s.data(), s.size()};
str_list[i] = MutableSlice{s.data(), s.size()};
}
cntx->transaction->SetExecCmd(scmd.descr);

View File

@ -68,7 +68,6 @@ class Service {
void Exec(CmdArgList args, ConnectionContext* cntx);
void RegisterCommands();
base::VarzValue::Map GetVarzStats();