diff --git a/server/dragonfly_connection.cc b/server/dragonfly_connection.cc index a12ecee..a0324a6 100644 --- a/server/dragonfly_connection.cc +++ b/server/dragonfly_connection.cc @@ -5,10 +5,10 @@ #include "server/dragonfly_connection.h" #include +#include #include -#include "base/io_buf.h" #include "base/logging.h" #include "server/command_registry.h" #include "server/conn_context.h" @@ -16,11 +16,14 @@ #include "server/memcache_parser.h" #include "server/redis_parser.h" #include "server/server_state.h" +#include "server/transaction.h" #include "util/fiber_sched_algo.h" #include "util/tls/tls_socket.h" +#include "util/uring/uring_socket.h" using namespace util; using namespace std; +using nonstd::make_unexpected; namespace this_fiber = boost::this_fiber; namespace fibers = boost::fibers; @@ -49,6 +52,13 @@ void RespToArgList(const RespVec& src, CmdArgVec* dest) { } } +// TODO: to implement correct matcher according to HTTP spec +// https://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html +// One place to find a good implementation would be https://github.com/h2o/picohttpparser +bool MatchHttp11Line(string_view line) { + return absl::StartsWith(line, "GET ") && absl::EndsWith(line, "HTTP/1.1"); +} + constexpr size_t kMinReadSize = 256; constexpr size_t kMaxReadSize = 32_KB; @@ -69,7 +79,7 @@ struct Connection::Shutdown { }; Connection::Connection(Protocol protocol, Service* service, SSL_CTX* ctx) - : service_(service), ctx_(ctx) { + : io_buf_{kMinReadSize}, service_(service), ctx_(ctx) { protocol_ = protocol; switch (protocol) { @@ -134,76 +144,134 @@ void Connection::HandleRequests() { cc_.reset(new ConnectionContext(peer, this)); cc_->shard_set = &service_->shard_set(); + // TODO: to move this interface to LinuxSocketBase so we won't need to cast. + uring::UringSocket* us = static_cast(socket_.get()); + + bool poll_armed = true; + uint32_t poll_id = us->PollEvent(POLLERR | POLLHUP, [&](uint32_t mask) { + VLOG(1) << "Got event " << mask; + cc_->conn_state.mask |= ConnectionState::CONN_CLOSING; + if (cc_->transaction) { + cc_->transaction->BreakOnClose(); + } + + evc_.notify(); // Notify dispatch fiber. + poll_armed = false; + }); + + io::Result check_res = CheckForHttpProto(peer); + if (!check_res) + return; + if (*check_res) { + LOG(INFO) << "HTTP1.1 identified"; + } + InputLoop(peer); + if (poll_armed) { + us->CancelPoll(poll_id); + } VLOG(1) << "Closed connection for peer " << remote_ep; } -void Connection::InputLoop(FiberSocketBase* peer) { - base::IoBuf io_buf{kMinReadSize}; +io::Result Connection::CheckForHttpProto(util::FiberSocketBase* peer) { + size_t last_len = 0; + do { + auto buf = io_buf_.AppendBuffer(); + ::io::Result recv_sz = peer->Recv(buf); + if (!recv_sz) { + return make_unexpected(recv_sz.error()); + } + io_buf_.CommitWrite(*recv_sz); + string_view ib = ToSV(io_buf_.InputBuffer().subspan(last_len)); + size_t pos = ib.find('\n'); + if (pos != string_view::npos) { + ib = ToSV(io_buf_.InputBuffer().first(last_len + pos)); + if (ib.size() < 10 || ib.back() != '\r') + return false; + ib.remove_suffix(1); + return MatchHttp11Line(ib); + } + last_len = io_buf_.InputLen(); + } while (last_len < 1024); + return false; +} + +void Connection::InputLoop(FiberSocketBase* peer) { auto dispatch_fb = fibers::fiber(fibers::launch::dispatch, [&] { DispatchFiber(peer); }); ConnectionStats* stats = ServerState::tl_connection_stats(); stats->num_conns++; - stats->read_buf_capacity += io_buf.Capacity(); + stats->read_buf_capacity += io_buf_.Capacity(); - ParserStatus status = OK; + ParserStatus parse_status = OK; std::error_code ec; + if (io_buf_.InputLen() > 0) { + if (redis_parser_) { + parse_status = ParseRedis(); + } else { + DCHECK(memcache_parser_); + parse_status = ParseMemcache(); + } + if (parse_status == ERROR) + goto finish; + } + do { - auto buf = io_buf.AppendBuffer(); - ::io::Result recv_sz = peer->Recv(buf); + io::MutableBytes append_buf = io_buf_.AppendBuffer(); + ::io::Result recv_sz = peer->Recv(append_buf); ++stats->io_reads_cnt; if (!recv_sz) { ec = recv_sz.error(); - status = OK; + parse_status = OK; break; } - - io_buf.CommitWrite(*recv_sz); + io_buf_.CommitWrite(*recv_sz); if (redis_parser_) - status = ParseRedis(&io_buf); + parse_status = ParseRedis(); else { DCHECK(memcache_parser_); - status = ParseMemcache(&io_buf); + parse_status = ParseMemcache(); } - if (status == NEED_MORE) { - status = OK; + if (parse_status == NEED_MORE) { + parse_status = OK; - size_t capacity = io_buf.Capacity(); + size_t capacity = io_buf_.Capacity(); if (capacity < kMaxReadSize) { size_t parser_hint = redis_parser_->parselen_hint(); if (parser_hint > capacity) { - io_buf.Reserve(std::min(kMaxReadSize, parser_hint)); - } else if (buf.size() == *recv_sz && buf.size() > capacity / 2) { + io_buf_.Reserve(std::min(kMaxReadSize, parser_hint)); + } else if (append_buf.size() == *recv_sz && append_buf.size() > capacity / 2) { // Last io used most of the io_buf to the end. - io_buf.Reserve(capacity * 2); // Valid growth range. + io_buf_.Reserve(capacity * 2); // Valid growth range. } - if (capacity < io_buf.Capacity()) { - VLOG(1) << "Growing io_buf to " << io_buf.Capacity(); - stats->read_buf_capacity += (io_buf.Capacity() - capacity); + if (capacity < io_buf_.Capacity()) { + VLOG(1) << "Growing io_buf to " << io_buf_.Capacity(); + stats->read_buf_capacity += (io_buf_.Capacity() - capacity); } } - } else if (status != OK) { + } else if (parse_status != OK) { break; } } while (peer->IsOpen() && !cc_->ec()); +finish: cc_->conn_state.mask |= ConnectionState::CONN_CLOSING; // Signal dispatch to close. evc_.notify(); dispatch_fb.join(); - stats->read_buf_capacity -= io_buf.Capacity(); + stats->read_buf_capacity -= io_buf_.Capacity(); if (cc_->ec()) { ec = cc_->ec(); } else { - if (status == ERROR) { - VLOG(1) << "Error stats " << status; + if (parse_status == ERROR) { + VLOG(1) << "Error stats " << parse_status; if (redis_parser_) { SendProtocolError(RedisParser::Result(parser_error_), peer); } else { @@ -224,7 +292,7 @@ void Connection::InputLoop(FiberSocketBase* peer) { --stats->num_conns; } -auto Connection::ParseRedis(base::IoBuf* io_buf) -> ParserStatus { +auto Connection::ParseRedis() -> ParserStatus { RespVec args; CmdArgVec arg_vec; uint32_t consumed = 0; @@ -232,7 +300,7 @@ auto Connection::ParseRedis(base::IoBuf* io_buf) -> ParserStatus { RedisParser::Result result = RedisParser::OK; do { - result = redis_parser_->Parse(io_buf->InputBuffer(), &consumed, &args); + result = redis_parser_->Parse(io_buf_.InputBuffer(), &consumed, &args); if (result == RedisParser::OK && !args.empty()) { RespExpr& first = args.front(); @@ -245,7 +313,7 @@ auto Connection::ParseRedis(base::IoBuf* io_buf) -> ParserStatus { // dispatch fiber pulls the last record but is still processing the command and then this // fiber enters the condition below and executes out of order. bool is_sync_dispatch = !cc_->conn_state.IsRunViaDispatch(); - if (dispatch_q_.empty() && is_sync_dispatch && consumed >= io_buf->InputLen()) { + if (dispatch_q_.empty() && is_sync_dispatch && consumed >= io_buf_.InputLen()) { RespToArgList(args, &arg_vec); service_->DispatchCommand(CmdArgList{arg_vec.data(), arg_vec.size()}, cc_.get()); } else { @@ -259,7 +327,7 @@ auto Connection::ParseRedis(base::IoBuf* io_buf) -> ParserStatus { } } } - io_buf->ConsumeInput(consumed); + io_buf_.ConsumeInput(consumed); } while (RedisParser::OK == result && !cc_->ec()); parser_error_ = result; @@ -272,24 +340,24 @@ auto Connection::ParseRedis(base::IoBuf* io_buf) -> ParserStatus { return ERROR; } -auto Connection::ParseMemcache(base::IoBuf* io_buf) -> ParserStatus { +auto Connection::ParseMemcache() -> ParserStatus { MemcacheParser::Result result = MemcacheParser::OK; uint32_t consumed = 0; MemcacheParser::Command cmd; string_view value; do { - string_view str = ToSV(io_buf->InputBuffer()); + string_view str = ToSV(io_buf_.InputBuffer()); result = memcache_parser_->Parse(str, &consumed, &cmd); if (result != MemcacheParser::OK) { - io_buf->ConsumeInput(consumed); + io_buf_.ConsumeInput(consumed); break; } size_t total_len = consumed; if (MemcacheParser::IsStoreCmd(cmd.type)) { total_len += cmd.bytes_len + 2; - if (io_buf->InputLen() >= total_len) { + if (io_buf_.InputLen() >= total_len) { value = str.substr(consumed, cmd.bytes_len); // TODO: dispatch. } else { @@ -302,10 +370,10 @@ auto Connection::ParseMemcache(base::IoBuf* io_buf) -> ParserStatus { // dispatch fiber pulls the last record but is still processing the command and then this // fiber enters the condition below and executes out of order. bool is_sync_dispatch = (cc_->conn_state.mask & ConnectionState::ASYNC_DISPATCH) == 0; - if (dispatch_q_.empty() && is_sync_dispatch && consumed >= io_buf->InputLen()) { + if (dispatch_q_.empty() && is_sync_dispatch && consumed >= io_buf_.InputLen()) { service_->DispatchMC(cmd, value, cc_.get()); } - io_buf->ConsumeInput(consumed); + io_buf_.ConsumeInput(consumed); } while (!cc_->ec()); parser_error_ = result; @@ -335,7 +403,7 @@ void Connection::DispatchFiber(util::FiberSocketBase* peer) { std::unique_ptr req{dispatch_q_.front()}; dispatch_q_.pop_front(); - + ++stats->pipelined_cmd_cnt; cc_->SetBatchMode(!dispatch_q_.empty()); diff --git a/server/dragonfly_connection.h b/server/dragonfly_connection.h index be96b7b..152a89f 100644 --- a/server/dragonfly_connection.h +++ b/server/dragonfly_connection.h @@ -47,12 +47,15 @@ class Connection : public util::Connection { void HandleRequests() final; + // + io::Result CheckForHttpProto(util::FiberSocketBase* peer); void InputLoop(util::FiberSocketBase* peer); void DispatchFiber(util::FiberSocketBase* peer); - ParserStatus ParseRedis(base::IoBuf* buf); - ParserStatus ParseMemcache(base::IoBuf* buf); + ParserStatus ParseRedis(); + ParserStatus ParseMemcache(); + base::IoBuf io_buf_; std::unique_ptr redis_parser_; std::unique_ptr memcache_parser_; Service* service_;