PUBSUB: more polishes.

Implement atomic message passing that will allow handling commands in a subscribed state together with passing
message from publishers.
This commit is contained in:
Roman Gershman 2022-04-01 10:12:32 +03:00
parent 077ebe460d
commit ba71e9a943
11 changed files with 141 additions and 60 deletions

View File

@ -192,7 +192,9 @@ API 2.0
- [ ] PUBSUB CHANNELS
- [X] SUBSCRIBE
- [X] UNSUBSCRIBE
- [X] Server Family
- [ ] PSUBSCRIBE
- [ ] PUNSUBSCRIBE
- [ ] Server Family
- [ ] WATCH
- [ ] UNWATCH
- [X] DISCARD
@ -238,14 +240,6 @@ API 2.0
- [ ] PFADD
- [ ] PFCOUNT
- [ ] PFMERGE
- [ ] PUBSUB Family
- [ ] PSUBSCRIBE
- [ ] PUBSUB
- [ ] PUBLISH
- [ ] PUNSUBSCRIBE
- [ ] SUBSCRIBE
- [ ] UNSUBSCRIBE
Commands that I prefer avoid implementing before launch:
- PUNSUBSCRIBE

View File

@ -42,11 +42,12 @@ class ConnectionContext {
}
// connection state / properties.
bool async_dispatch: 1;
bool async_dispatch: 1; // whether this connection is currently handled by dispatch fiber.
bool conn_closing: 1;
bool req_auth: 1;
bool replica_conn: 1;
bool authenticated: 1;
bool force_dispatch: 1; // whether we should route all requests to the dispatch fiber.
virtual void OnClose() {}
private:

View File

@ -11,7 +11,6 @@
#include <boost/fiber/operations.hpp>
#include "base/logging.h"
#include "facade/conn_context.h"
#include "facade/memcache_parser.h"
#include "facade/redis_parser.h"
@ -73,6 +72,21 @@ bool MatchHttp11Line(string_view line) {
constexpr size_t kMinReadSize = 256;
constexpr size_t kMaxReadSize = 32_KB;
struct AsyncMsg {
absl::Span<const std::string_view> msg_vec;
fibers_ext::BlockingCounter bc;
AsyncMsg(absl::Span<const std::string_view> vec, fibers_ext::BlockingCounter b)
: msg_vec(vec), bc(move(b)) {
}
};
#ifdef ABSL_HAVE_ADDRESS_SANITIZER
constexpr size_t kReqStorageSize = 88;
#else
constexpr size_t kReqStorageSize = 120;
#endif
} // namespace
struct Connection::Shutdown {
@ -89,13 +103,15 @@ struct Connection::Shutdown {
}
};
struct Connection::Request {
absl::FixedArray<MutableSlice> args;
absl::FixedArray<MutableSlice, 6> args;
// I do not use mi_heap_t explicitly but mi_stl_allocator at the end does the same job
// of using the thread's heap.
// The capacity is chosen so that we allocate a fully utilized (512 bytes) block.
absl::FixedArray<char, 190, mi_stl_allocator<char>> storage;
// The capacity is chosen so that we allocate a fully utilized (256 bytes) block.
absl::FixedArray<char, kReqStorageSize, mi_stl_allocator<char>> storage;
AsyncMsg* async_msg = nullptr; // allocated and released via mi_malloc.
Request(size_t nargs, size_t capacity) : args(nargs), storage(capacity) {
}
@ -109,7 +125,8 @@ Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener,
protocol_ = protocol;
constexpr size_t kReqSz = sizeof(Connection::Request);
(void)kReqSz;
static_assert(kReqSz <= 256 && kReqSz >= 232);
// LOG(INFO) << "kReqSz: " << kReqSz;
switch (protocol) {
case Protocol::REDIS:
@ -223,7 +240,26 @@ void Connection::RegisterOnBreak(BreakerCb breaker_cb) {
breaker_cb_ = breaker_cb;
}
io::Result<bool> Connection::CheckForHttpProto(util::FiberSocketBase* peer) {
void Connection::SendMsgVecAsync(absl::Span<const std::string_view> msg_vec,
fibers_ext::BlockingCounter bc) {
if (cc_->conn_closing) {
bc.Dec();
return;
}
void* ptr = mi_malloc(sizeof(AsyncMsg));
AsyncMsg* amsg = new (ptr) AsyncMsg(msg_vec, move(bc));
ptr = mi_malloc(sizeof(Request));
Request* req = new (ptr) Request(0, 0);
req->async_msg = amsg;
dispatch_q_.push_back(req);
if (dispatch_q_.size() == 1) {
evc_.notify();
}
}
io::Result<bool> Connection::CheckForHttpProto(FiberSocketBase* peer) {
size_t last_len = 0;
do {
auto buf = io_buf_.AppendBuffer();
@ -338,16 +374,15 @@ auto Connection::ParseRedis() -> ParserStatus {
// We use ASYNC_DISPATCH as a lock to avoid out-of-order replies when the
// dispatch fiber pulls the last record but is still processing the command and then this
// fiber enters the condition below and executes out of order.
bool is_sync_dispatch = !cc_->async_dispatch;
bool is_sync_dispatch = !cc_->async_dispatch && !cc_->force_dispatch;
if (dispatch_q_.empty() && is_sync_dispatch && consumed >= io_buf_.InputLen()) {
RespToArgList(args, &arg_vec);
service_->DispatchCommand(CmdArgList{arg_vec.data(), arg_vec.size()}, cc_.get());
} else {
// Dispatch via queue to speedup input reading
// We could use
// Dispatch via queue to speedup input reading.
Request* req = FromArgs(std::move(args), tlh);
dispatch_q_.emplace_back(req);
dispatch_q_.push_back(req);
if (dispatch_q_.size() == 1) {
evc_.notify();
} else if (dispatch_q_.size() > 10) {
@ -505,12 +540,21 @@ void Connection::DispatchFiber(util::FiberSocketBase* peer) {
Request* req = dispatch_q_.front();
dispatch_q_.pop_front();
++stats->pipelined_cmd_cnt;
if (req->async_msg) {
++stats->async_writes_cnt;
builder->SendRawVec(req->async_msg->msg_vec);
req->async_msg->bc.Dec();
builder->SetBatchMode(!dispatch_q_.empty());
cc_->async_dispatch = true;
service_->DispatchCommand(CmdArgList{req->args.data(), req->args.size()}, cc_.get());
cc_->async_dispatch = false;
req->async_msg->~AsyncMsg();
mi_free(req->async_msg);
} else {
++stats->pipelined_cmd_cnt;
builder->SetBatchMode(!dispatch_q_.empty());
cc_->async_dispatch = true;
service_->DispatchCommand(CmdArgList{req->args.data(), req->args.size()}, cc_.get());
cc_->async_dispatch = false;
}
req->~Request();
mi_free(req);
}
@ -521,6 +565,12 @@ void Connection::DispatchFiber(util::FiberSocketBase* peer) {
while (!dispatch_q_.empty()) {
Request* req = dispatch_q_.front();
dispatch_q_.pop_front();
if (req->async_msg) {
req->async_msg->bc.Dec();
req->async_msg->~AsyncMsg();
mi_free(req->async_msg);
}
req->~Request();
mi_free(req);
}

View File

@ -13,7 +13,7 @@
#include "facade/facade_types.h"
#include "facade/resp_expr.h"
#include "util/connection.h"
#include "util/fibers/event_count.h"
#include "util/fibers/fibers_ext.h"
#include "util/http/http_handler.h"
typedef struct ssl_ctx_st SSL_CTX;
@ -28,8 +28,8 @@ class MemcacheParser;
class Connection : public util::Connection {
public:
Connection(Protocol protocol, util::HttpListenerBase* http_listener,
SSL_CTX* ctx, ServiceInterface* service);
Connection(Protocol protocol, util::HttpListenerBase* http_listener, SSL_CTX* ctx,
ServiceInterface* service);
~Connection();
using error_code = std::error_code;
@ -46,6 +46,12 @@ class Connection : public util::Connection {
using BreakerCb = std::function<void(uint32_t)>;
void RegisterOnBreak(BreakerCb breaker_cb);
// This interface is used to pass a raw message directly to the socket via zero-copy interface.
// Once the msg is sent "bc" will be decreased so that caller could release the underlying
// storage for the message.
void SendMsgVecAsync(absl::Span<const std::string_view> msg_vec,
util::fibers_ext::BlockingCounter bc);
protected:
void OnShutdown() override;
@ -80,6 +86,7 @@ class Connection : public util::Connection {
std::deque<Request*> dispatch_q_; // coordinated via evc_.
util::fibers_ext::EventCount evc_;
unsigned parser_error_ = 0;
Protocol protocol_;
struct Shutdown;

View File

@ -21,7 +21,7 @@ constexpr size_t kSizeConnStats = sizeof(ConnectionStats);
ConnectionStats& ConnectionStats::operator+=(const ConnectionStats& o) {
// To break this code deliberately if we add/remove a field to this struct.
static_assert(kSizeConnStats == 144);
static_assert(kSizeConnStats == 152);
ADD(num_conns);
ADD(num_replicas);
@ -30,8 +30,9 @@ ConnectionStats& ConnectionStats::operator+=(const ConnectionStats& o) {
ADD(io_read_bytes);
ADD(io_write_cnt);
ADD(io_write_bytes);
ADD(pipelined_cmd_cnt);
ADD(command_cnt);
ADD(pipelined_cmd_cnt);
ADD(async_writes_cnt);
for (const auto& k_v : o.err_count) {
err_count[k_v.first] += k_v.second;
@ -98,6 +99,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, Connection* owner) : ow
req_auth = false;
replica_conn = false;
authenticated = false;
force_dispatch = false;
}
Protocol ConnectionContext::protocol() const {

View File

@ -33,6 +33,9 @@ struct ConnectionStats {
size_t command_cnt = 0;
size_t pipelined_cmd_cnt = 0;
// Writes count that happenned via SendRawMessageAsync call.
size_t async_writes_cnt = 0;
ConnectionStats& operator+=(const ConnectionStats& o);
};

View File

@ -75,12 +75,22 @@ void SinkReplyBuilder::Send(const iovec* v, uint32_t len) {
}
}
void SinkReplyBuilder::SendDirect(std::string_view raw) {
void SinkReplyBuilder::SendRaw(std::string_view raw) {
iovec v = {IoVec(raw)};
Send(&v, 1);
}
void SinkReplyBuilder::SendRawVec(absl::Span<const std::string_view> msg_vec) {
iovec v[msg_vec.size()];
for (unsigned i = 0; i < msg_vec.size(); ++i) {
v[i].iov_base = const_cast<char*>(msg_vec[i].data());
v[i].iov_len = msg_vec[i].size();
}
Send(v, msg_vec.size());
}
MCReplyBuilder::MCReplyBuilder(::io::Sink* sink) : SinkReplyBuilder(sink) {
}
@ -210,7 +220,7 @@ void RedisReplyBuilder::SendError(OpStatus status) {
void RedisReplyBuilder::SendLong(long num) {
string str = absl::StrCat(":", num, kCRLF);
SendDirect(str);
SendRaw(str);
}
void RedisReplyBuilder::SendDouble(double val) {
@ -228,7 +238,7 @@ void RedisReplyBuilder::SendMGetResponse(const OptResp* resp, uint32_t count) {
}
}
SendDirect(res);
SendRaw(res);
}
void RedisReplyBuilder::SendSimpleStrArr(const std::string_view* arr, uint32_t count) {
@ -238,11 +248,11 @@ void RedisReplyBuilder::SendSimpleStrArr(const std::string_view* arr, uint32_t c
StrAppend(&res, "+", arr[i], kCRLF);
}
SendDirect(res);
SendRaw(res);
}
void RedisReplyBuilder::SendNullArray() {
SendDirect("*-1\r\n");
SendRaw("*-1\r\n");
}
void RedisReplyBuilder::SendStringArr(absl::Span<const std::string_view> arr) {
@ -252,7 +262,7 @@ void RedisReplyBuilder::SendStringArr(absl::Span<const std::string_view> arr) {
StrAppend(&res, "$", arr[i].size(), kCRLF);
res.append(arr[i]).append(kCRLF);
}
SendDirect(res);
SendRaw(res);
}
void RedisReplyBuilder::SendStringArr(absl::Span<const string> arr) {
@ -262,11 +272,11 @@ void RedisReplyBuilder::SendStringArr(absl::Span<const string> arr) {
StrAppend(&res, "$", arr[i].size(), kCRLF);
res.append(arr[i]).append(kCRLF);
}
SendDirect(res);
SendRaw(res);
}
void RedisReplyBuilder::StartArray(unsigned len) {
SendDirect(absl::StrCat("*", len, kCRLF));
SendRaw(absl::StrCat("*", len, kCRLF));
}
void ReqSerializer::SendCommand(std::string_view str) {

View File

@ -54,7 +54,8 @@ class SinkReplyBuilder {
}
//! Sends a string as is without any formatting. raw should be encoded according to the protocol.
void SendDirect(std::string_view str);
void SendRaw(std::string_view str);
void SendRawVec(absl::Span<const std::string_view> msg_vec);
// Common for both MC and Redis.
virtual void SendError(std::string_view str, std::string_view type = std::string_view{}) = 0;

View File

@ -23,6 +23,7 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
DCHECK(to_add);
conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
this->force_dispatch = true;
}
for (size_t i = 0; i < args.size(); ++i) {
@ -45,6 +46,7 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
if (!to_add && conn_state.subscribe_info->channels.empty()) {
conn_state.subscribe_info.reset();
force_dispatch = false;
}
sort(channels.begin(), channels.end());

View File

@ -10,6 +10,7 @@ extern "C" {
#include <absl/cleanup/cleanup.h>
#include <absl/strings/ascii.h>
#include <absl/strings/str_format.h>
#include <xxhash.h>
#include <boost/fiber/operations.hpp>
@ -874,41 +875,50 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
auto cb = [&] { return EngineShard::tlocal()->channel_slice().FetchSubscribers(channel); };
vector<ChannelSlice::Subscriber> res = shard_set_.Await(sid, std::move(cb));
vector<ChannelSlice::Subscriber> subsriber_arr = shard_set_.Await(sid, std::move(cb));
atomic_uint32_t published{0};
if (!res.empty()) {
sort(res.begin(), res.end(),
if (!subsriber_arr.empty()) {
sort(subsriber_arr.begin(), subsriber_arr.end(),
[](const auto& left, const auto& right) { return left.thread_id < right.thread_id; });
vector<unsigned> slices(shard_set_.pool()->size(), UINT_MAX);
for (size_t i = 0; i < res.size(); ++i) {
if (slices[res[i].thread_id] > i) {
slices[res[i].thread_id] = i;
for (size_t i = 0; i < subsriber_arr.size(); ++i) {
if (slices[subsriber_arr[i].thread_id] > i) {
slices[subsriber_arr[i].thread_id] = i;
}
}
auto cb = [&](unsigned idx, util::ProactorBase*) {
fibers_ext::BlockingCounter bc(subsriber_arr.size());
char prefix[] = "*3\r\n$7\r\nmessage\r\n$";
char msg_size[32] = {0};
char channel_size[32] = {0};
absl::SNPrintF(msg_size, sizeof(msg_size), "%u\r\n", message.size());
absl::SNPrintF(channel_size, sizeof(channel_size), "%u\r\n", channel.size());
string_view msg_arr[] = {prefix, channel_size, channel, "\r\n$", msg_size, message, "\r\n"};
auto publish_cb = [&, bc](unsigned idx, util::ProactorBase*) mutable {
unsigned start = slices[idx];
for (unsigned i = start; i < res.size(); ++i) {
if (res[i].thread_id != idx)
for (unsigned i = start; i < subsriber_arr.size(); ++i) {
if (subsriber_arr[i].thread_id != idx)
break;
if (!res[i].conn_cntx->conn_closing) {
published.fetch_add(1, memory_order_relaxed);
// TODO: this is wrong because ReplyBuilder does not guarantee atomicity if used
// concurrently by multiple fibers.
string_view msg_arr[3] = {"message", channel, message};
(*res[i].conn_cntx)->SendStringArr(msg_arr);
}
published.fetch_add(1, memory_order_relaxed);
subsriber_arr[i].conn_cntx->owner()->SendMsgVecAsync(msg_arr, bc);
}
};
shard_set_.pool()->AwaitFiberOnAll(cb);
shard_set_.pool()->Await(publish_cb);
bc.Wait(); // Wait for all the messages to be sent.
}
for (auto& s : res) {
// If subsriber connections are closing they will wait
// for the tokens to be reclaimed in OnClose(). This guarantees that subscribers we gathered
// still exist till we finish publishing.
for (auto& s : subsriber_arr) {
s.borrow_token.Dec();
}

View File

@ -162,7 +162,7 @@ void ServerFamily::StatsMC(std::string_view section, facade::ConnectionContext*
absl::StrAppend(&info, "END\r\n");
MCReplyBuilder* builder = static_cast<MCReplyBuilder*>(cntx->reply_builder());
builder->SendDirect(info);
builder->SendRaw(info);
#undef ADD_LINE
}
@ -462,6 +462,7 @@ tcp_port:)";
absl::StrAppend(&info, "keyspace_misses:", -1, "\n");
absl::StrAppend(&info, "total_reads_processed:", m.conn_stats.io_read_cnt, "\n");
absl::StrAppend(&info, "total_writes_processed:", m.conn_stats.io_write_cnt, "\n");
absl::StrAppend(&info, "async_writes_count:", m.conn_stats.async_writes_cnt, "\n");
}
if (should_enter("REPLICATION")) {
@ -579,7 +580,7 @@ void ServerFamily::ReplicaOf(CmdArgList args, ConnectionContext* cntx) {
}
void ServerFamily::Role(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendDirect("*3\r\n$6\r\nmaster\r\n:0\r\n*0\r\n");
(*cntx)->SendRaw("*3\r\n$6\r\nmaster\r\n:0\r\n*0\r\n");
}
void ServerFamily::Script(CmdArgList args, ConnectionContext* cntx) {