// Copyright 2021, Roman Gershman.  All rights reserved.
// See LICENSE for licensing terms.
//

#include "server/dragonfly_connection.h"

#include <absl/container/flat_hash_map.h>
#include <boost/fiber/operations.hpp>

#include "base/io_buf.h"
#include "base/logging.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/main_service.h"
#include "server/memcache_parser.h"
#include "server/redis_parser.h"
#include "server/server_state.h"
#include "util/fiber_sched_algo.h"
#include "util/tls/tls_socket.h"

using namespace util;
using namespace std;
namespace this_fiber = boost::this_fiber;
namespace fibers = boost::fibers;

namespace dfly {
namespace {

void SendProtocolError(RedisParser::Result pres, FiberSocketBase* peer) {
  string res("-ERR Protocol error: ");
  if (pres == RedisParser::BAD_BULKLEN) {
    res.append("invalid bulk length\r\n");
  } else {
    CHECK_EQ(RedisParser::BAD_ARRAYLEN, pres);
    res.append("invalid multibulk length\r\n");
  }

  auto size_res = peer->Send(::io::Buffer(res));
  if (!size_res) {
    LOG(WARNING) << "Error " << size_res.error();
  }
}

void RespToArgList(const RespVec& src, CmdArgVec* dest) {
  dest->resize(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    (*dest)[i] = ToMSS(src[i].GetBuf());
  }
}

constexpr size_t kMinReadSize = 256;
constexpr size_t kMaxReadSize = 32_KB;

}  // namespace

struct Connection::Shutdown {
  absl::flat_hash_map<ShutdownHandle, ShutdownCb> map;
  ShutdownHandle next_handle = 1;

  ShutdownHandle Add(ShutdownCb cb) {
    map[next_handle] = move(cb);
    return next_handle++;
  }

  void Remove(ShutdownHandle sh) {
    map.erase(sh);
  }
};

Connection::Connection(Protocol protocol, Service* service, SSL_CTX* ctx)
    : service_(service), ctx_(ctx) {
  protocol_ = protocol;

  switch (protocol) {
    case Protocol::REDIS:
      redis_parser_.reset(new RedisParser);
      break;
    case Protocol::MEMCACHE:
      memcache_parser_.reset(new MemcacheParser);
      break;
  }
}

Connection::~Connection() {
}

void Connection::OnShutdown() {
  VLOG(1) << "Connection::OnShutdown";

  if (shutdown_) {
    for (const auto& k_v : shutdown_->map) {
      k_v.second();
    }
  }
}

auto Connection::RegisterShutdownHook(ShutdownCb cb) -> ShutdownHandle {
  if (!shutdown_) {
    shutdown_ = make_unique<Shutdown>();
  }
  return shutdown_->Add(std::move(cb));
}

void Connection::UnregisterShutdownHook(ShutdownHandle id) {
  if (shutdown_) {
    shutdown_->Remove(id);
    if (shutdown_->map.empty())
      shutdown_.reset();
  }
}

void Connection::HandleRequests() {
  this_fiber::properties<FiberProps>().set_name("DflyConnection");

  int val = 1;
  CHECK_EQ(0, setsockopt(socket_->native_handle(), SOL_TCP, TCP_NODELAY, &val, sizeof(val)));

  auto remote_ep = socket_->RemoteEndpoint();

  std::unique_ptr<tls::TlsSocket> tls_sock;
  if (ctx_) {
    tls_sock.reset(new tls::TlsSocket(socket_.get()));
    tls_sock->InitSSL(ctx_);

    FiberSocketBase::AcceptResult aresult = tls_sock->Accept();
    if (!aresult) {
      LOG(WARNING) << "Error handshaking " << aresult.error().message();
      return;
    }
    VLOG(1) << "TLS handshake succeeded";
  }

  FiberSocketBase* peer = tls_sock ? (FiberSocketBase*)tls_sock.get() : socket_.get();
  cc_.reset(new ConnectionContext(peer, this));
  cc_->shard_set = &service_->shard_set();

  InputLoop(peer);

  VLOG(1) << "Closed connection for peer " << remote_ep;
}
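// A rough sketch of how InputLoop() grows its read buffer (assuming parselen_hint() approximates
// the pending bulk length): a SET with a ~1 KB value does not fit into the initial 256-byte
// buffer, the parser returns INPUT_PENDING, and the loop below reserves
// min(kMaxReadSize, parselen_hint()) so the next Recv() can pull in the rest of the value without
// further reallocations.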
void Connection::InputLoop(FiberSocketBase* peer) {
  base::IoBuf io_buf{kMinReadSize};

  auto dispatch_fb = fibers::fiber(fibers::launch::dispatch, [&] { DispatchFiber(peer); });

  ConnectionStats* stats = ServerState::tl_connection_stats();
  stats->num_conns++;
  stats->read_buf_capacity += io_buf.Capacity();

  ParserStatus status = OK;
  std::error_code ec;

  do {
    auto buf = io_buf.AppendBuffer();
    ::io::Result<size_t> recv_sz = peer->Recv(buf);
    ++stats->io_reads_cnt;

    if (!recv_sz) {
      ec = recv_sz.error();
      status = OK;
      break;
    }

    io_buf.CommitWrite(*recv_sz);

    if (redis_parser_) {
      status = ParseRedis(&io_buf);
    } else {
      DCHECK(memcache_parser_);
      status = ParseMemcache(&io_buf);
    }

    if (status == NEED_MORE) {
      status = OK;

      size_t capacity = io_buf.Capacity();
      if (capacity < kMaxReadSize) {
        size_t parser_hint = redis_parser_->parselen_hint();
        if (parser_hint > capacity) {
          io_buf.Reserve(std::min(kMaxReadSize, parser_hint));
        } else if (buf.size() == *recv_sz && buf.size() > capacity / 2) {
          // The last io filled most of the io_buf up to its end.
          io_buf.Reserve(capacity * 2);  // Valid growth range.
        }

        if (capacity < io_buf.Capacity()) {
          VLOG(1) << "Growing io_buf to " << io_buf.Capacity();
          stats->read_buf_capacity += (io_buf.Capacity() - capacity);
        }
      }
    } else if (status != OK) {
      break;
    }
  } while (peer->IsOpen() && !cc_->ec());

  cc_->conn_state.mask |= ConnectionState::CONN_CLOSING;  // Signal the dispatch fiber to close.
  evc_.notify();
  dispatch_fb.join();

  stats->read_buf_capacity -= io_buf.Capacity();

  if (cc_->ec()) {
    ec = cc_->ec();
  } else {
    if (status == ERROR) {
      VLOG(1) << "Error status " << status;
      if (redis_parser_) {
        SendProtocolError(RedisParser::Result(parser_error_), peer);
      } else {
        string_view sv{"CLIENT_ERROR bad command line format\r\n"};
        auto size_res = peer->Send(::io::Buffer(sv));
        if (!size_res) {
          LOG(WARNING) << "Error " << size_res.error();
          ec = size_res.error();
        }
      }
    }
  }

  if (ec && !FiberSocketBase::IsConnClosed(ec)) {
    LOG(WARNING) << "Socket error " << ec;
  }

  --stats->num_conns;
}
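// For illustration (a hypothetical trace, not a parser guarantee): a client that pipelines two
// commands in one TCP segment, e.g. "*1\r\n$4\r\nPING\r\n*1\r\n$4\r\nPING\r\n", leaves unread
// bytes after the first Parse() call, so consumed < InputLen() and the first request is queued on
// dispatch_q_ for DispatchFiber instead of being dispatched inline in ParseRedis() below.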
auto Connection::ParseRedis(base::IoBuf* io_buf) -> ParserStatus {
  RespVec args;
  CmdArgVec arg_vec;

  uint32_t consumed = 0;
  RedisParser::Result result = RedisParser::OK;

  do {
    result = redis_parser_->Parse(io_buf->InputBuffer(), &consumed, &args);

    if (result == RedisParser::OK && !args.empty()) {
      RespExpr& first = args.front();
      if (first.type == RespExpr::STRING) {
        DVLOG(2) << "Got Args with first token " << ToSV(first.GetBuf());
      }

      // An optimization to skip dispatch_q_ if no pipelining is identified.
      // We use ASYNC_DISPATCH as a lock to avoid out-of-order replies when the
      // dispatch fiber pulls the last record but is still processing the command and then this
      // fiber enters the condition below and executes out of order.
      bool is_sync_dispatch = !cc_->conn_state.IsRunViaDispatch();
      if (dispatch_q_.empty() && is_sync_dispatch && consumed >= io_buf->InputLen()) {
        RespToArgList(args, &arg_vec);
        service_->DispatchCommand(CmdArgList{arg_vec.data(), arg_vec.size()}, cc_.get());
      } else {
        // Dispatch via the queue to speed up input reading.
        Request* req = FromArgs(std::move(args));
        dispatch_q_.emplace_back(req);
        if (dispatch_q_.size() == 1) {
          evc_.notify();
        } else if (dispatch_q_.size() > 10) {
          this_fiber::yield();
        }
      }
    }
    io_buf->ConsumeInput(consumed);
  } while (RedisParser::OK == result && !cc_->ec());

  parser_error_ = result;
  if (result == RedisParser::OK)
    return OK;

  if (result == RedisParser::INPUT_PENDING)
    return NEED_MORE;

  return ERROR;
}

auto Connection::ParseMemcache(base::IoBuf* io_buf) -> ParserStatus {
  MemcacheParser::Result result = MemcacheParser::OK;
  uint32_t consumed = 0;
  MemcacheParser::Command cmd;
  string_view value;

  do {
    string_view str = ToSV(io_buf->InputBuffer());
    result = memcache_parser_->Parse(str, &consumed, &cmd);

    if (result != MemcacheParser::OK) {
      io_buf->ConsumeInput(consumed);
      break;
    }

    size_t total_len = consumed;
    if (MemcacheParser::IsStoreCmd(cmd.type)) {
      total_len += cmd.bytes_len + 2;
      if (io_buf->InputLen() >= total_len) {
        value = str.substr(consumed, cmd.bytes_len);
        // TODO: dispatch.
      } else {
        return NEED_MORE;
      }
    }

    // An optimization to skip dispatch_q_ if no pipelining is identified.
    // We use ASYNC_DISPATCH as a lock to avoid out-of-order replies when the
    // dispatch fiber pulls the last record but is still processing the command and then this
    // fiber enters the condition below and executes out of order.
    bool is_sync_dispatch = (cc_->conn_state.mask & ConnectionState::ASYNC_DISPATCH) == 0;
    if (dispatch_q_.empty() && is_sync_dispatch && consumed >= io_buf->InputLen()) {
      service_->DispatchMC(cmd, value, cc_.get());
    }

    io_buf->ConsumeInput(consumed);
  } while (!cc_->ec());

  parser_error_ = result;
  if (result == MemcacheParser::OK)
    return OK;

  if (result == MemcacheParser::INPUT_PENDING)
    return NEED_MORE;

  return ERROR;
}
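// For illustration (a hypothetical request, assuming the standard memcached text protocol): for
// "set key 0 0 5\r\nhello\r\n", Parse() consumes the command line, cmd.bytes_len is 5, and
// total_len in ParseMemcache() above adds the 5 value bytes plus the trailing "\r\n";
// ParseMemcache() returns NEED_MORE until the whole value has been received.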
// DispatchFiber handles commands coming from the InputLoop.
// Thus, InputLoop can quickly read data from the input buffer, parse it and push it into the
// dispatch queue, and DispatchFiber will run those commands asynchronously with InputLoop.
// Note: in some cases, InputLoop may decide to dispatch directly and bypass the DispatchFiber.
void Connection::DispatchFiber(util::FiberSocketBase* peer) {
  this_fiber::properties<FiberProps>().set_name("DispatchFiber");

  ConnectionStats* stats = ServerState::tl_connection_stats();

  while (!cc_->ec()) {
    evc_.await([this] { return cc_->conn_state.IsClosing() || !dispatch_q_.empty(); });
    if (cc_->conn_state.IsClosing())
      break;

    std::unique_ptr<Request> req{dispatch_q_.front()};
    dispatch_q_.pop_front();

    ++stats->pipelined_cmd_cnt;

    cc_->SetBatchMode(!dispatch_q_.empty());
    cc_->conn_state.mask |= ConnectionState::ASYNC_DISPATCH;
    service_->DispatchCommand(CmdArgList{req->args.data(), req->args.size()}, cc_.get());
    cc_->conn_state.mask &= ~ConnectionState::ASYNC_DISPATCH;
  }

  cc_->conn_state.mask |= ConnectionState::CONN_CLOSING;
}

auto Connection::FromArgs(RespVec args) -> Request* {
  DCHECK(!args.empty());

  size_t backed_sz = 0;
  for (const auto& arg : args) {
    CHECK_EQ(RespExpr::STRING, arg.type);
    backed_sz += arg.GetBuf().size();
  }
  DCHECK(backed_sz);

  Request* req = new Request{args.size(), backed_sz};

  auto* next = req->storage.data();
  for (size_t i = 0; i < args.size(); ++i) {
    auto buf = args[i].GetBuf();
    size_t s = buf.size();
    memcpy(next, buf.data(), s);
    req->args[i] = MutableStrSpan(next, s);
    next += s;
  }

  return req;
}

}  // namespace dfly