Adding memcached protocol support for dragonfly

This commit is contained in:
Roman Gershman 2021-11-23 12:39:35 +02:00
parent d4b708d33c
commit 5ebbfa5a64
14 changed files with 278 additions and 54 deletions

View File

@ -1,7 +1,8 @@
add_executable(dragonfly dfly_main.cc) add_executable(dragonfly dfly_main.cc)
cxx_link(dragonfly base dragonfly_lib) cxx_link(dragonfly base dragonfly_lib)
add_library(dragonfly_lib command_registry.cc config_flags.cc db_slice.cc dragonfly_listener.cc add_library(dragonfly_lib command_registry.cc config_flags.cc conn_context.cc db_slice.cc
dragonfly_listener.cc
dragonfly_connection.cc engine_shard_set.cc dragonfly_connection.cc engine_shard_set.cc
main_service.cc memcache_parser.cc main_service.cc memcache_parser.cc
redis_parser.cc resp_expr.cc reply_builder.cc) redis_parser.cc resp_expr.cc reply_builder.cc)

View File

@ -55,7 +55,7 @@ void CommandRegistry::Command(CmdArgList args, ConnectionContext* cntx) {
StrAppend(&resp, ":", cd.key_arg_step(), "\r\n"); StrAppend(&resp, ":", cd.key_arg_step(), "\r\n");
} }
cntx->SendDirect(resp); cntx->SendRespBlob(resp);
} }
namespace CO { namespace CO {

15
server/conn_context.cc Normal file
View File

@ -0,0 +1,15 @@
// Copyright 2021, Beeri 15. All rights reserved.
// Author: Roman Gershman (romange@gmail.com)
//
#include "server/conn_context.h"
#include "server/dragonfly_connection.h"
namespace dfly {
ConnectionContext::ConnectionContext(::io::Sink* stream, Connection* owner)
: ReplyBuilder(owner->protocol(), stream), owner_(owner) {
}
} // namespace dfly

View File

@ -14,14 +14,15 @@ class CommandId;
class ConnectionContext : public ReplyBuilder { class ConnectionContext : public ReplyBuilder {
public: public:
ConnectionContext(::io::Sink* stream, Connection* owner) : ReplyBuilder(stream), owner_(owner) { ConnectionContext(::io::Sink* stream, Connection* owner);
}
// TODO: to introduce proper accessors. // TODO: to introduce proper accessors.
const CommandId* cid = nullptr; const CommandId* cid = nullptr;
EngineShardSet* shard_set = nullptr; EngineShardSet* shard_set = nullptr;
Connection* owner() { return owner_;} Connection* owner() {
return owner_;
}
private: private:
Connection* owner_; Connection* owner_;

View File

@ -11,6 +11,7 @@
DEFINE_int32(http_port, 8080, "Http port."); DEFINE_int32(http_port, 8080, "Http port.");
DECLARE_uint32(port); DECLARE_uint32(port);
DECLARE_uint32(memcache_port);
using namespace util; using namespace util;
@ -24,7 +25,11 @@ void RunEngine(ProactorPool* pool, AcceptServer* acceptor, HttpListener<>* http)
service.RegisterHttp(http); service.RegisterHttp(http);
} }
acceptor->AddListener(FLAGS_port, new Listener{&service}); acceptor->AddListener(FLAGS_port, new Listener{Protocol::REDIS, &service});
if (FLAGS_memcache_port > 0) {
acceptor->AddListener(FLAGS_memcache_port, new Listener{Protocol::MEMCACHE, &service});
}
acceptor->Run(); acceptor->Run();
acceptor->Wait(); acceptor->Wait();

16
server/dfly_protocol.h Normal file
View File

@ -0,0 +1,16 @@
// Copyright 2021, Beeri 15. All rights reserved.
// Author: Roman Gershman (romange@gmail.com)
//
#pragma once
#include <cstdint>
namespace dfly {
enum class Protocol : uint8_t {
MEMCACHE = 1,
REDIS = 2
};
} // namespace dfly

View File

@ -11,6 +11,7 @@
#include "server/command_registry.h" #include "server/command_registry.h"
#include "server/conn_context.h" #include "server/conn_context.h"
#include "server/main_service.h" #include "server/main_service.h"
#include "server/memcache_parser.h"
#include "server/redis_parser.h" #include "server/redis_parser.h"
#include "util/fiber_sched_algo.h" #include "util/fiber_sched_algo.h"
#include "util/tls/tls_socket.h" #include "util/tls/tls_socket.h"
@ -69,8 +70,18 @@ struct Connection::Shutdown {
} }
}; };
Connection::Connection(Service* service, SSL_CTX* ctx) : service_(service), ctx_(ctx) { Connection::Connection(Protocol protocol, Service* service, SSL_CTX* ctx)
: service_(service), ctx_(ctx) {
protocol_ = protocol;
switch (protocol) {
case Protocol::REDIS:
redis_parser_.reset(new RedisParser); redis_parser_.reset(new RedisParser);
break;
case Protocol::MEMCACHE:
memcache_parser_.reset(new MemcacheParser);
break;
}
} }
Connection::~Connection() { Connection::~Connection() {
@ -143,7 +154,14 @@ void Connection::InputLoop(FiberSocketBase* peer) {
} }
io_buf.CommitWrite(*recv_sz); io_buf.CommitWrite(*recv_sz);
if (redis_parser_)
status = ParseRedis(&io_buf); status = ParseRedis(&io_buf);
else {
DCHECK(memcache_parser_);
status = ParseMemcache(&io_buf);
}
if (status == NEED_MORE) { if (status == NEED_MORE) {
status = OK; status = OK;
} else if (status != OK) { } else if (status != OK) {
@ -206,4 +224,44 @@ auto Connection::ParseRedis(base::IoBuf* io_buf) -> ParserStatus {
return ERROR; return ERROR;
} }
auto Connection::ParseMemcache(base::IoBuf* io_buf) -> ParserStatus {
MemcacheParser::Result result = MemcacheParser::OK;
uint32_t consumed = 0;
MemcacheParser::Command cmd;
string_view value;
do {
string_view str = ToSV(io_buf->InputBuffer());
result = memcache_parser_->Parse(str, &consumed, &cmd);
if (result != MemcacheParser::OK) {
io_buf->ConsumeInput(consumed);
break;
}
size_t total_len = consumed;
if (MemcacheParser::IsStoreCmd(cmd.type)) {
total_len += cmd.bytes_len + 2;
if (io_buf->InputLen() >= total_len) {
value = str.substr(consumed, cmd.bytes_len);
// TODO: dispatch.
} else {
return NEED_MORE;
}
}
service_->DispatchMC(cmd, value, cc_.get());
io_buf->ConsumeInput(total_len);
} while (!cc_->ec());
parser_error_ = result;
if (result == MemcacheParser::OK)
return OK;
if (result == MemcacheParser::INPUT_PENDING)
return NEED_MORE;
return ERROR;
}
} // namespace dfly } // namespace dfly

View File

@ -7,6 +7,7 @@
#include "util/connection.h" #include "util/connection.h"
#include "base/io_buf.h" #include "base/io_buf.h"
#include "server/dfly_protocol.h"
typedef struct ssl_ctx_st SSL_CTX; typedef struct ssl_ctx_st SSL_CTX;
@ -15,10 +16,11 @@ namespace dfly {
class ConnectionContext; class ConnectionContext;
class RedisParser; class RedisParser;
class Service; class Service;
class MemcacheParser;
class Connection : public util::Connection { class Connection : public util::Connection {
public: public:
Connection(Service* service, SSL_CTX* ctx); Connection(Protocol protocol, Service* service, SSL_CTX* ctx);
~Connection(); ~Connection();
using error_code = std::error_code; using error_code = std::error_code;
@ -28,6 +30,8 @@ class Connection : public util::Connection {
ShutdownHandle RegisterShutdownHook(ShutdownCb cb); ShutdownHandle RegisterShutdownHook(ShutdownCb cb);
void UnregisterShutdownHook(ShutdownHandle id); void UnregisterShutdownHook(ShutdownHandle id);
Protocol protocol() const { return protocol_;}
protected: protected:
void OnShutdown() override; void OnShutdown() override;
@ -39,14 +43,16 @@ class Connection : public util::Connection {
void InputLoop(util::FiberSocketBase* peer); void InputLoop(util::FiberSocketBase* peer);
ParserStatus ParseRedis(base::IoBuf* buf); ParserStatus ParseRedis(base::IoBuf* buf);
ParserStatus ParseMemcache(base::IoBuf* buf);
std::unique_ptr<RedisParser> redis_parser_; std::unique_ptr<RedisParser> redis_parser_;
std::unique_ptr<MemcacheParser> memcache_parser_;
Service* service_; Service* service_;
SSL_CTX* ctx_; SSL_CTX* ctx_;
std::unique_ptr<ConnectionContext> cc_; std::unique_ptr<ConnectionContext> cc_;
unsigned parser_error_ = 0; unsigned parser_error_ = 0;
Protocol protocol_;
struct Shutdown; struct Shutdown;
std::unique_ptr<Shutdown> shutdown_; std::unique_ptr<Shutdown> shutdown_;
}; };

View File

@ -81,7 +81,7 @@ static SSL_CTX* CreateSslCntx() {
return ctx; return ctx;
} }
Listener::Listener(Service* e) : engine_(e) { Listener::Listener(Protocol protocol, Service* e) : engine_(e), protocol_(protocol) {
if (FLAGS_tls) { if (FLAGS_tls) {
OPENSSL_init_ssl(OPENSSL_INIT_SSL_DEFAULT, NULL); OPENSSL_init_ssl(OPENSSL_INIT_SSL_DEFAULT, NULL);
ctx_ = CreateSslCntx(); ctx_ = CreateSslCntx();
@ -93,7 +93,7 @@ Listener::~Listener() {
} }
util::Connection* Listener::NewConnection(ProactorBase* proactor) { util::Connection* Listener::NewConnection(ProactorBase* proactor) {
return new Connection{engine_, ctx_}; return new Connection{protocol_, engine_, ctx_};
} }
void Listener::PreShutdown() { void Listener::PreShutdown() {

View File

@ -5,6 +5,7 @@
#pragma once #pragma once
#include "util/listener_interface.h" #include "util/listener_interface.h"
#include "server/dfly_protocol.h"
typedef struct ssl_ctx_st SSL_CTX; typedef struct ssl_ctx_st SSL_CTX;
@ -14,7 +15,7 @@ class Service;
class Listener : public util::ListenerInterface { class Listener : public util::ListenerInterface {
public: public:
Listener(Service*); Listener(Protocol protocol, Service*);
~Listener(); ~Listener();
private: private:
@ -28,6 +29,7 @@ class Listener : public util::ListenerInterface {
Service* engine_; Service* engine_;
std::atomic_uint32_t next_id_{0}; std::atomic_uint32_t next_id_{0};
Protocol protocol_;
SSL_CTX* ctx_ = nullptr; SSL_CTX* ctx_ = nullptr;
}; };

View File

@ -16,6 +16,7 @@
#include "util/varz.h" #include "util/varz.h"
DEFINE_uint32(port, 6380, "Redis port"); DEFINE_uint32(port, 6380, "Redis port");
DEFINE_uint32(memcache_port, 0, "Memcached port");
namespace std { namespace std {
@ -127,6 +128,47 @@ void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) {
request_latency_usec.IncBy(cmd_str, (end_usec - start_usec) / 1000); request_latency_usec.IncBy(cmd_str, (end_usec - start_usec) / 1000);
} }
void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view value,
ConnectionContext* cntx) {
absl::InlinedVector<MutableStrSpan, 8> args;
char cmd_name[16];
char set_opt[4] = {0};
switch (cmd.type) {
case MemcacheParser::REPLACE:
strcpy(cmd_name, "SET");
strcpy(set_opt, "XX");
break;
case MemcacheParser::SET:
strcpy(cmd_name, "SET");
break;
case MemcacheParser::ADD:
strcpy(cmd_name, "SET");
strcpy(set_opt, "NX");
break;
default:
cntx->SendMCClientError("bad command line format");
return;
}
args.emplace_back(cmd_name, strlen(cmd_name));
char* key = const_cast<char*>(cmd.key.data());
args.emplace_back(key, cmd.key.size());
if (MemcacheParser::IsStoreCmd(cmd.type)) {
char* v = const_cast<char*>(value.data());
args.emplace_back(v, value.size());
if (set_opt[0]) {
args.emplace_back(set_opt, strlen(set_opt));
}
}
CmdArgList arg_list{args.data(), args.size()};
DispatchCommand(arg_list, cntx);
}
void Service::RegisterHttp(HttpListenerBase* listener) { void Service::RegisterHttp(HttpListenerBase* listener) {
CHECK_NOTNULL(listener); CHECK_NOTNULL(listener);
} }
@ -138,12 +180,12 @@ void Service::Ping(CmdArgList args, ConnectionContext* cntx) {
ping_qps.Inc(); ping_qps.Inc();
if (args.size() == 1) { if (args.size() == 1) {
return cntx->SendSimpleString("PONG"); return cntx->SendSimpleRespString("PONG");
} }
std::string_view arg = ArgS(args, 1); std::string_view arg = ArgS(args, 1);
DVLOG(2) << "Ping " << arg; DVLOG(2) << "Ping " << arg;
return cntx->SendSimpleString(arg); return cntx->SendSimpleRespString(arg);
} }
void Service::Set(CmdArgList args, ConnectionContext* cntx) { void Service::Set(CmdArgList args, ConnectionContext* cntx) {
@ -159,7 +201,8 @@ void Service::Set(CmdArgList args, ConnectionContext* cntx) {
auto [it, res] = es->db_slice.AddOrFind(0, key); auto [it, res] = es->db_slice.AddOrFind(0, key);
it->second = val; it->second = val;
}); });
cntx->SendOk();
cntx->SendStored();
} }

View File

@ -8,6 +8,7 @@
#include "server/command_registry.h" #include "server/command_registry.h"
#include "server/engine_shard_set.h" #include "server/engine_shard_set.h"
#include "util/http/http_handler.h" #include "util/http/http_handler.h"
#include "server/memcache_parser.h"
namespace util { namespace util {
class AcceptServer; class AcceptServer;
@ -29,6 +30,8 @@ class Service {
void Shutdown(); void Shutdown();
void DispatchCommand(CmdArgList args, ConnectionContext* cntx); void DispatchCommand(CmdArgList args, ConnectionContext* cntx);
void DispatchMC(const MemcacheParser::Command& cmd, std::string_view value,
ConnectionContext* cntx);
uint32_t shard_count() const { uint32_t shard_count() const {
return shard_set_.size(); return shard_set_.size();

View File

@ -26,23 +26,38 @@ constexpr char kSimplePref[] = "+";
} // namespace } // namespace
RespSerializer::RespSerializer(io::Sink* stream) : sink_(stream) { BaseSerializer::BaseSerializer(io::Sink* sink) : sink_(sink) {
} }
void RespSerializer::Send(const iovec* v, uint32_t len) { void BaseSerializer::Send(const iovec* v, uint32_t len) {
error_code ec = sink_->Write(v, len); error_code ec = sink_->Write(v, len);
if (ec) { if (ec) {
ec_ = ec; ec_ = ec;
} }
} }
void RespSerializer::SendDirect(std::string_view raw) { void BaseSerializer::SendDirect(std::string_view raw) {
iovec v = {IoVec(raw)}; iovec v = {IoVec(raw)};
Send(&v, 1); Send(&v, 1);
} }
void ReplyBuilder::SendBulkString(std::string_view str) { void RespSerializer::SendNull() {
constexpr char kNullStr[] = "$-1\r\n";
iovec v[] = {IoVec(kNullStr)};
Send(v, ABSL_ARRAYSIZE(v));
}
void RespSerializer::SendSimpleString(std::string_view str) {
iovec v[3] = {IoVec(kSimplePref), IoVec(str), IoVec(kCRLF)};
Send(v, ABSL_ARRAYSIZE(v));
}
void RespSerializer::SendBulkString(std::string_view str) {
char tmp[absl::numbers_internal::kFastToBufferSize + 3]; char tmp[absl::numbers_internal::kFastToBufferSize + 3];
tmp[0] = '$'; // Format length tmp[0] = '$'; // Format length
char* next = absl::numbers_internal::FastIntToBuffer(uint32_t(str.size()), tmp + 1); char* next = absl::numbers_internal::FastIntToBuffer(uint32_t(str.size()), tmp + 1);
@ -57,17 +72,54 @@ void ReplyBuilder::SendBulkString(std::string_view str) {
return Send(v, ABSL_ARRAYSIZE(v)); return Send(v, ABSL_ARRAYSIZE(v));
} }
void ReplyBuilder::SendError(std::string_view str) { void MemcacheSerializer::SendStored() {
SendDirect("STORED\r\n");
}
void MemcacheSerializer::SendError() {
SendDirect("ERROR\r\n");
}
ReplyBuilder::ReplyBuilder(Protocol protocol, ::io::Sink* sink) : protocol_(protocol) {
if (protocol == Protocol::REDIS) {
serializer_.reset(new RespSerializer(sink));
} else {
DCHECK(protocol == Protocol::MEMCACHE);
serializer_.reset(new MemcacheSerializer(sink));
}
}
void ReplyBuilder::SendStored() {
if (protocol_ == Protocol::REDIS) {
as_resp()->SendSimpleString("OK");
} else {
as_mc()->SendStored();
}
}
void ReplyBuilder::SendMCClientError(string_view str) {
DCHECK(protocol_ == Protocol::MEMCACHE);
iovec v[] = {IoVec("CLIENT_ERROR"), IoVec(str), IoVec(kCRLF)};
serializer_->Send(v, ABSL_ARRAYSIZE(v));
}
void ReplyBuilder::SendError(string_view str) {
DCHECK(protocol_ == Protocol::REDIS);
if (str[0] == '-') { if (str[0] == '-') {
iovec v[] = {IoVec(str), IoVec(kCRLF)}; iovec v[] = {IoVec(str), IoVec(kCRLF)};
return Send(v, ABSL_ARRAYSIZE(v)); serializer_->Send(v, ABSL_ARRAYSIZE(v));
} else { } else {
iovec v[] = {IoVec(kErrPref), IoVec(str), IoVec(kCRLF)}; iovec v[] = {IoVec(kErrPref), IoVec(str), IoVec(kCRLF)};
return Send(v, ABSL_ARRAYSIZE(v)); serializer_->Send(v, ABSL_ARRAYSIZE(v));
} }
} }
void ReplyBuilder::SendError(OpStatus status) { void ReplyBuilder::SendError(OpStatus status) {
DCHECK(protocol_ == Protocol::REDIS);
switch (status) { switch (status) {
case OpStatus::OK: case OpStatus::OK:
SendOk(); SendOk();
@ -82,18 +134,4 @@ void ReplyBuilder::SendError(OpStatus status) {
} }
} }
void ReplyBuilder::SendNull() {
constexpr char kNullStr[] = "$-1\r\n";
iovec v[] = {IoVec(kNullStr)};
Send(v, ABSL_ARRAYSIZE(v));
}
void ReplyBuilder::SendSimpleString(std::string_view str) {
iovec v[3] = {IoVec(kSimplePref), IoVec(str), IoVec(kCRLF)};
Send(v, ABSL_ARRAYSIZE(v));
}
} // namespace dfly } // namespace dfly

View File

@ -2,17 +2,16 @@
// Author: Roman Gershman (romange@gmail.com) // Author: Roman Gershman (romange@gmail.com)
// //
#include <string_view> #include <string_view>
#include <optional>
#include "io/sync_stream_interface.h" #include "io/sync_stream_interface.h"
#include "server/dfly_protocol.h"
#include "server/op_status.h" #include "server/op_status.h"
namespace dfly { namespace dfly {
class RespSerializer { class BaseSerializer {
public: public:
explicit RespSerializer(::io::Sink* sink); explicit BaseSerializer(::io::Sink* sink);
std::error_code ec() const { std::error_code ec() const {
return ec_; return ec_;
@ -23,41 +22,78 @@ class RespSerializer {
ec_ = std::make_error_code(std::errc::connection_aborted); ec_ = std::make_error_code(std::errc::connection_aborted);
} }
//! Sends a string as is without any formatting. raw should be RESP-encoded. //! Sends a string as is without any formatting. raw should be encoded according to the protocol.
void SendDirect(std::string_view str); void SendDirect(std::string_view str);
::io::Sink* sink() { return sink_; } ::io::Sink* sink() {
return sink_;
}
protected:
void Send(const iovec* v, uint32_t len); void Send(const iovec* v, uint32_t len);
::io::Sink* sink_;
private: private:
::io::Sink* sink_;
std::error_code ec_; std::error_code ec_;
}; };
class ReplyBuilder : public RespSerializer { class RespSerializer : public BaseSerializer {
public: public:
explicit ReplyBuilder(::io::Sink* stream) : RespSerializer(stream) { RespSerializer(::io::Sink* sink) : BaseSerializer(sink) {
} }
//! See https://redis.io/topics/protocol
void SendSimpleString(std::string_view str);
void SendNull();
/// aka "$6\r\nfoobar\r\n" /// aka "$6\r\nfoobar\r\n"
void SendBulkString(std::string_view str); void SendBulkString(std::string_view str);
};
void SendNull(); class MemcacheSerializer : public BaseSerializer {
public:
void SendOk() { explicit MemcacheSerializer(::io::Sink* sink) : BaseSerializer(sink) {
return SendSimpleString("OK");
} }
void SendStored();
void SendError();
};
class ReplyBuilder {
public:
ReplyBuilder(Protocol protocol, ::io::Sink* stream);
void SendStored();
void SendError(std::string_view str); void SendError(std::string_view str);
void SendError(OpStatus status); void SendError(OpStatus status);
//! See https://redis.io/topics/protocol void SendOk() {
void SendSimpleString(std::string_view str); as_resp()->SendSimpleString("OK");
}
private: std::error_code ec() const {
return serializer_->ec();
}
void SendMCClientError(std::string_view str);
void SendSimpleRespString(std::string_view str) {
as_resp()->SendSimpleString(str);
}
void SendRespBlob(std::string_view str) {
as_resp()->SendDirect(str);
}
private:
RespSerializer* as_resp() {
return static_cast<RespSerializer*>(serializer_.get());
}
MemcacheSerializer* as_mc() {
return static_cast<MemcacheSerializer*>(serializer_.get());
}
std::unique_ptr<BaseSerializer> serializer_;
Protocol protocol_;
}; };
} // namespace dfly } // namespace dfly