diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..bf80c11 --- /dev/null +++ b/.clang-format @@ -0,0 +1,19 @@ +# --- +# We'll use defaults from the Google style, but with 2 columns indentation. +BasedOnStyle: Google +IndentWidth: 2 +ColumnLimit: 100 +--- +Language: Cpp +AllowShortLoopsOnASingleLine: false +AllowShortFunctionsOnASingleLine: false +AllowShortIfStatementsOnASingleLine: false +AlwaysBreakTemplateDeclarations: false +ConstructorInitializerAllOnOneLineOrOnePerLine: false +DerivePointerAlignment: false +PointerAlignment: Left +BasedOnStyle: Google +ColumnLimit: 100 +--- +Language: Proto +BasedOnStyle: Google diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index 651f46e..78144bd 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -1,7 +1,13 @@ add_executable(dragonfly dfly_main.cc) cxx_link(dragonfly base dragonfly_lib) -add_library(dragonfly_lib dragonfly_listener.cc dragonfly_connection.cc main_service.cc) +add_library(dragonfly_lib dragonfly_listener.cc dragonfly_connection.cc main_service.cc + redis_parser.cc resp_expr.cc) cxx_link(dragonfly_lib uring_fiber_lib fibers_ext strings_lib http_server_lib) + +add_library(dfly_test_lib test_utils.cc) +cxx_link(dfly_test_lib dragonfly_lib gtest_main_ext) + +cxx_test(redis_parser_test dfly_test_lib LABELS DFLY) diff --git a/server/redis_parser.cc b/server/redis_parser.cc new file mode 100644 index 0000000..8cfc09b --- /dev/null +++ b/server/redis_parser.cc @@ -0,0 +1,389 @@ +// Copyright 2021, Beeri 15. All rights reserved. +// Author: Roman Gershman (romange@gmail.com) +// +#include "server/redis_parser.h" + +#include + +#include "base/logging.h" + +namespace dfly { + +using namespace std; + +namespace { + +constexpr int kMaxArrayLen = 1024; +constexpr int64_t kMaxBulkLen = 64 * (1ul << 20); // 64MB. + +} // namespace + +auto RedisParser::Parse(Buffer str, uint32_t* consumed, RespExpr::Vec* res) -> Result { + *consumed = 0; + res->clear(); + + if (str.size() < 2) { + return INPUT_PENDING; + } + + if (state_ == CMD_COMPLETE_S) { + state_ = INIT_S; + } + + if (state_ == INIT_S) { + InitStart(str[0], res); + } + + if (!cached_expr_) + cached_expr_ = res; + + while (state_ != CMD_COMPLETE_S) { + last_consumed_ = 0; + switch (state_) { + case ARRAY_LEN_S: + last_result_ = ConsumeArrayLen(str); + break; + case PARSE_ARG_S: + if (str.size() < 4) { + last_result_ = INPUT_PENDING; + } else { + last_result_ = ParseArg(str); + } + break; + case INLINE_S: + DCHECK(parse_stack_.empty()); + last_result_ = ParseInline(str); + break; + case BULK_STR_S: + last_result_ = ConsumeBulk(str); + break; + case FINISH_ARG_S: + HandleFinishArg(); + break; + default: + LOG(FATAL) << "Unexpected state " << int(state_); + } + + *consumed += last_consumed_; + + if (last_result_ != OK) { + break; + } + str.remove_prefix(last_consumed_); + } + + if (last_result_ == INPUT_PENDING) { + StashState(res); + } else if (last_result_ == OK) { + DCHECK(cached_expr_); + if (res != cached_expr_) { + DCHECK(!stash_.empty()); + + *res = *cached_expr_; + } + } + + return last_result_; +} + +void RedisParser::InitStart(uint8_t prefix_b, RespExpr::Vec* res) { + buf_stash_.clear(); + stash_.clear(); + cached_expr_ = res; + parse_stack_.clear(); + last_stashed_level_ = 0; + last_stashed_index_ = 0; + + switch (prefix_b) { + case '$': + state_ = PARSE_ARG_S; + parse_stack_.emplace_back(1, cached_expr_); // expression of length 1. + break; + case '*': + state_ = ARRAY_LEN_S; + break; + default: + state_ = INLINE_S; + break; + } +} + +void RedisParser::StashState(RespExpr::Vec* res) { + if (cached_expr_->empty() && stash_.empty()) { + cached_expr_ = nullptr; + return; + } + + if (cached_expr_ == res) { + stash_.emplace_back(new RespExpr::Vec(*res)); + cached_expr_ = stash_.back().get(); + } + + DCHECK_LT(last_stashed_level_, stash_.size()); + while (true) { + auto& cur = *stash_[last_stashed_level_]; + + for (; last_stashed_index_ < cur.size(); ++last_stashed_index_) { + auto& e = cur[last_stashed_index_]; + if (RespExpr::STRING == e.type) { + Buffer& ebuf = get(e.u); + if (ebuf.empty() && last_stashed_index_ + 1 == cur.size()) + break; + if (!ebuf.empty() && !e.has_support) { + BlobPtr ptr(new uint8_t[ebuf.size()]); + memcpy(ptr.get(), ebuf.data(), ebuf.size()); + ebuf = Buffer{ptr.get(), ebuf.size()}; + buf_stash_.push_back(std::move(ptr)); + e.has_support = true; + } + } + } + + if (last_stashed_level_ + 1 == stash_.size()) + break; + ++last_stashed_level_; + last_stashed_index_ = 0; + } +} + +auto RedisParser::ParseInline(Buffer str) -> Result { + DCHECK(!str.empty()); + + uint8_t* ptr = str.begin(); + uint8_t* end = str.end(); + uint8_t* token_start = ptr; + + if (is_broken_token_) { + while (ptr != end && *ptr > 32) + ++ptr; + + size_t len = ptr - token_start; + + ExtendLastString(Buffer(token_start, len)); + if (ptr != end) { + is_broken_token_ = false; + } + } + + auto is_finish = [&] { return ptr == end || *ptr == '\n'; }; + + while (true) { + while (!is_finish() && *ptr <= 32) { + ++ptr; + } + // We do not test for \r in order to accept 'nc' input. + if (is_finish()) + break; + + DCHECK(!is_broken_token_); + + token_start = ptr; + while (ptr != end && *ptr > 32) + ++ptr; + + cached_expr_->emplace_back(RespExpr::STRING); + cached_expr_->back().u = Buffer{token_start, size_t(ptr - token_start)}; + } + + last_consumed_ = ptr - str.data(); + if (ptr == end) { // we have not finished parsing. + if (ptr[-1] > 32) { + // we stopped in the middle of the token. + is_broken_token_ = true; + } + + return INPUT_PENDING; + } else { + ++last_consumed_; // consume the delimiter as well. + } + state_ = CMD_COMPLETE_S; + + return OK; +} + +auto RedisParser::ParseNum(Buffer str, int64_t* res) -> Result { + if (str.size() < 4) { + return INPUT_PENDING; + } + + char* s = reinterpret_cast(str.data() + 1); + char* pos = reinterpret_cast(memchr(s, '\n', str.size() - 1)); + if (!pos) { + return str.size() < 32 ? INPUT_PENDING : BAD_INT; + } + if (pos[-1] != '\r') { + return BAD_INT; + } + + bool success = absl::SimpleAtoi(std::string_view{s, size_t(pos - s - 1)}, res); + if (!success) { + return BAD_INT; + } + last_consumed_ = (pos - s) + 2; + + return OK; +} + +auto RedisParser::ConsumeArrayLen(Buffer str) -> Result { + int64_t len; + + Result res = ParseNum(str, &len); + switch (res) { + case INPUT_PENDING: + return INPUT_PENDING; + case BAD_INT: + return BAD_ARRAYLEN; + case OK: + if (len < -1 || len > kMaxArrayLen) + return BAD_ARRAYLEN; + break; + default: + LOG(ERROR) << "Unexpected result " << res; + } + + // Already parsed array expression somewhere. Server should accept only single-level expressions. + if (!parse_stack_.empty()) + return BAD_STRING; + + // Similarly if our cached expr is not empty. + if (!cached_expr_->empty()) + return BAD_STRING; + + if (len <= 0) { + cached_expr_->emplace_back(len == -1 ? RespExpr::NIL_ARRAY : RespExpr::ARRAY); + if (len < 0) + cached_expr_->back().u.emplace(nullptr); // nil + else { + static RespVec empty_vec; + cached_expr_->back().u = &empty_vec; + } + state_ = (parse_stack_.empty()) ? CMD_COMPLETE_S : FINISH_ARG_S; + + return OK; + } + + parse_stack_.emplace_back(len, cached_expr_); + DCHECK(cached_expr_->empty()); + state_ = PARSE_ARG_S; + + return OK; +} + +auto RedisParser::ParseArg(Buffer str) -> Result { + char c = str[0]; + if (c == '$') { + int64_t len; + + Result res = ParseNum(str, &len); + switch (res) { + case INPUT_PENDING: + return INPUT_PENDING; + case BAD_INT: + return BAD_ARRAYLEN; + case OK: + if (len < -1 || len > kMaxBulkLen) + return BAD_ARRAYLEN; + break; + default: + LOG(ERROR) << "Unexpected result " << res; + } + + if (len < 0) { + state_ = FINISH_ARG_S; + cached_expr_->emplace_back(RespExpr::NIL); + } else { + cached_expr_->emplace_back(RespExpr::STRING); + bulk_len_ = len; + state_ = BULK_STR_S; + } + cached_expr_->back().u = Buffer{}; + + return OK; + } + + return BAD_BULKLEN; +} + +auto RedisParser::ConsumeBulk(Buffer str) -> Result { + auto& bulk_str = get(cached_expr_->back().u); + + if (str.size() >= bulk_len_ + 2) { + if (str[bulk_len_] != '\r' || str[bulk_len_ + 1] != '\n') { + return BAD_STRING; + } + + if (bulk_len_) { + if (is_broken_token_) { + memcpy(bulk_str.end(), str.data(), bulk_len_); + bulk_str = Buffer{bulk_str.data(), bulk_str.size() + bulk_len_}; + } else { + bulk_str = str.subspan(0, bulk_len_); + } + } + is_broken_token_ = false; + state_ = FINISH_ARG_S; + last_consumed_ = bulk_len_ + 2; + bulk_len_ = 0; + + return OK; + } + + if (str.size() >= 32) { + DCHECK(bulk_len_); + size_t len = std::min(str.size(), bulk_len_); + + if (is_broken_token_) { + memcpy(bulk_str.end(), str.data(), len); + bulk_str = Buffer{bulk_str.data(), bulk_str.size() + len}; + DVLOG(1) << "Extending bulk stash to size " << bulk_str.size(); + } else { + DVLOG(1) << "New bulk stash size " << bulk_len_; + std::unique_ptr nb(new uint8_t[bulk_len_]); + memcpy(nb.get(), str.data(), len); + bulk_str = Buffer{nb.get(), len}; + buf_stash_.emplace_back(move(nb)); + is_broken_token_ = true; + cached_expr_->back().has_support = true; + } + last_consumed_ = len; + bulk_len_ -= len; + } + + return INPUT_PENDING; +} + +void RedisParser::HandleFinishArg() { + DCHECK(!parse_stack_.empty()); + DCHECK_GT(parse_stack_.back().first, 0u); + + while (true) { + --parse_stack_.back().first; + state_ = PARSE_ARG_S; + if (parse_stack_.back().first != 0) + break; + + parse_stack_.pop_back(); // pop 0. + if (parse_stack_.empty()) { + state_ = CMD_COMPLETE_S; + break; + } + cached_expr_ = parse_stack_.back().second; + } +} + +void RedisParser::ExtendLastString(Buffer str) { + DCHECK(!cached_expr_->empty() && cached_expr_->back().type == RespExpr::STRING); + DCHECK(!buf_stash_.empty()); + + Buffer& last_str = get(cached_expr_->back().u); + + DCHECK(last_str.data() == buf_stash_.back().get()); + + std::unique_ptr nb(new uint8_t[last_str.size() + str.size()]); + memcpy(nb.get(), last_str.data(), last_str.size()); + memcpy(nb.get() + last_str.size(), str.data(), str.size()); + last_str = RespExpr::Buffer{nb.get(), last_str.size() + str.size()}; + buf_stash_.back() = std::move(nb); +} + +} // namespace dfly diff --git a/server/redis_parser.h b/server/redis_parser.h new file mode 100644 index 0000000..ecdc1b6 --- /dev/null +++ b/server/redis_parser.h @@ -0,0 +1,91 @@ +// Copyright 2021, Beeri 15. All rights reserved. +// Author: Roman Gershman (romange@gmail.com) +// +#pragma once + +#include + +#include "resp_expr.h" + +namespace dfly { + +/** + * @brief Zero-copy (best-effort) parser. + * + */ +class RedisParser { + public: + enum Result { + OK, + INPUT_PENDING, + BAD_ARRAYLEN, + BAD_BULKLEN, + BAD_STRING, + BAD_INT + }; + using Buffer = RespExpr::Buffer; + + explicit RedisParser() { + } + + /** + * @brief Parses str into res. "consumed" stores number of bytes consumed from str. + * + * A caller should not invalidate str if the parser returns RESP_OK as long as he continues + * accessing res. However, if parser returns MORE_INPUT a caller may discard consumed + * part of str because parser caches the intermediate state internally according to 'consumed' + * result. + * + * Note: A parser does not always guarantee progress, i.e. if a small buffer was passed it may + * returns MORE_INPUT with consumed == 0. + * + */ + + Result Parse(Buffer str, uint32_t* consumed, RespVec* res); + + size_t stash_size() const { return stash_.size(); } + const std::vector>& stash() const { return stash_;} + + private: + void InitStart(uint8_t prefix_b, RespVec* res); + void StashState(RespVec* res); + + // Skips the first character (*). + Result ConsumeArrayLen(Buffer str); + Result ParseArg(Buffer str); + Result ConsumeBulk(Buffer str); + Result ParseInline(Buffer str); + + // Updates last_consumed_ + Result ParseNum(Buffer str, int64_t* res); + void HandleFinishArg(); + void ExtendLastString(Buffer str); + + enum State : uint8_t { + INIT_S = 0, + INLINE_S, + ARRAY_LEN_S, + PARSE_ARG_S, // Parse [$:+-]string\r\n + BULK_STR_S, + FINISH_ARG_S, + CMD_COMPLETE_S, + }; + + State state_ = INIT_S; + Result last_result_ = OK; + + uint32_t last_consumed_ = 0; + uint32_t bulk_len_ = 0; + uint32_t last_stashed_level_ = 0, last_stashed_index_ = 0; + + // expected expression length, pointer to expression vector. + absl::InlinedVector, 4> parse_stack_; + std::vector> stash_; + + using BlobPtr = std::unique_ptr; + std::vector buf_stash_; + RespVec* cached_expr_ = nullptr; + bool is_broken_token_ = false; +}; + +} // namespace dfly diff --git a/server/redis_parser_test.cc b/server/redis_parser_test.cc new file mode 100644 index 0000000..cd3869e --- /dev/null +++ b/server/redis_parser_test.cc @@ -0,0 +1,126 @@ +// Copyright 2021, Beeri 15. All rights reserved. +// Author: Roman Gershman (romange@gmail.com) +// + +#include "server/redis_parser.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "base/gtest.h" +#include "base/logging.h" +#include "server/test_utils.h" + +using namespace testing; +using namespace std; +namespace dfly { + +MATCHER_P(ArrArg, expected, absl::StrCat(negation ? "is not" : "is", " equal to:\n", expected)) { + if (arg.type != RespExpr::ARRAY) { + *result_listener << "\nWrong type: " << arg.type; + return false; + } + size_t exp_sz = expected; + size_t actual = get(arg.u)->size(); + + if (exp_sz != actual) { + *result_listener << "\nActual size: " << actual; + return false; + } + return true; +} + +class RedisParserTest : public testing::Test { + protected: + RedisParser::Result Parse(std::string_view str); + + RedisParser parser_; + RespExpr::Vec args_; + uint32_t consumed_; + + unique_ptr stash_; +}; + +RedisParser::Result RedisParserTest::Parse(std::string_view str) { + stash_.reset(new uint8_t[str.size()]); + auto* ptr = stash_.get(); + memcpy(ptr, str.data(), str.size()); + return parser_.Parse(RedisParser::Buffer{ptr, str.size()}, &consumed_, &args_); +} + +TEST_F(RedisParserTest, Inline) { + RespExpr e{RespExpr::STRING}; + ASSERT_EQ(RespExpr::STRING, e.type); + + const char kCmd1[] = "KEY VAL\r\n"; + + ASSERT_EQ(RedisParser::OK, Parse(kCmd1)); + EXPECT_EQ(strlen(kCmd1), consumed_); + EXPECT_THAT(args_, ElementsAre(StrArg("KEY"), StrArg("VAL"))); + + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("KEY")); + EXPECT_EQ(3, consumed_); + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" FOO ")); + EXPECT_EQ(5, consumed_); + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" BAR")); + EXPECT_EQ(4, consumed_); + ASSERT_EQ(RedisParser::OK, Parse(" \r\n ")); + EXPECT_EQ(3, consumed_); + EXPECT_THAT(args_, ElementsAre(StrArg("KEY"), StrArg("FOO"), StrArg("BAR"))); + + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" 1 2")); + EXPECT_EQ(4, consumed_); + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" 45")); + EXPECT_EQ(3, consumed_); + ASSERT_EQ(RedisParser::OK, Parse("\r\n")); + EXPECT_EQ(2, consumed_); + EXPECT_THAT(args_, ElementsAre(StrArg("1"), StrArg("2"), StrArg("45"))); + + // Empty queries return RESP_OK. + EXPECT_EQ(RedisParser::OK, Parse("\r\n")); + EXPECT_EQ(2, consumed_); +} + +TEST_F(RedisParserTest, InlineEscaping) { + LOG(ERROR) << "TBD: to be compliant with sdssplitargs"; // TODO: +} + +TEST_F(RedisParserTest, Multi1) { + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("*1\r\n$")); + EXPECT_EQ(4, consumed_); + + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("$4\r\nMSET")); + EXPECT_EQ(4, consumed_); + + ASSERT_EQ(RedisParser::OK, Parse("MSET\r\n*2\r\n")); + EXPECT_EQ(6, consumed_); + + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("*2\r\n$3\r\nKEY\r\n$3\r\nVAL")); + EXPECT_EQ(17, consumed_); + + ASSERT_EQ(RedisParser::OK, Parse("VAL\r\n")); + EXPECT_EQ(5, consumed_); + EXPECT_THAT(args_, ElementsAre("KEY", "VAL")); +} + +TEST_F(RedisParserTest, Multi2) { + const char kFirst[] = "*3\r\n$3\r\nSET\r\n$16\r\nkey:"; + const char kSecond[] = "key:000002273458\r\n$3\r\nVXK"; + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(kFirst)); + ASSERT_EQ(strlen(kFirst) - 4, consumed_); + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(kSecond)); + ASSERT_EQ(strlen(kSecond) - 3, consumed_); + ASSERT_EQ(RedisParser::OK, Parse("VXK\r\n*3\r\n$3\r\nSET")); + EXPECT_THAT(args_, ElementsAre("SET", "key:000002273458", "VXK")); +} + +TEST_F(RedisParserTest, InvalidMult1) { + ASSERT_EQ(RedisParser::BAD_BULKLEN, Parse("*2\r\n$3\r\nFOO\r\nBAR\r\n")); +} + +TEST_F(RedisParserTest, Empty) { + ASSERT_EQ(RedisParser::OK, Parse("*2\r\n$0\r\n\r\n$0\r\n\r\n")); +} + +} // namespace dfly diff --git a/server/resp_expr.cc b/server/resp_expr.cc new file mode 100644 index 0000000..b37da0b --- /dev/null +++ b/server/resp_expr.cc @@ -0,0 +1,74 @@ +// Copyright 2021, Beeri 15. All rights reserved. +// Author: Roman Gershman (romange@gmail.com) +// + +#include "server/resp_expr.h" + +#include "base/logging.h" + +namespace dfly { + +const char* RespExpr::TypeName(Type t) { + switch (t) { + case STRING: + return "string"; + case INT64: + return "int"; + case ARRAY: + return "array"; + case NIL_ARRAY: + return "nil-array"; + case NIL: + return "nil"; + case ERROR: + return "error"; + } + ABSL_INTERNAL_UNREACHABLE; +} + +} // namespace dfly + +namespace std { + +ostream& operator<<(ostream& os, const dfly::RespExpr& e) { + using dfly::RespExpr; + using dfly::ToAbsl; + + switch (e.type) { + case RespExpr::INT64: + os << "i" << get(e.u); + break; + case RespExpr::STRING: + os << "'" << ToAbsl(get(e.u)) << "'"; + break; + case RespExpr::NIL: + os << "nil"; + break; + case RespExpr::NIL_ARRAY: + os << "[]"; + break; + case RespExpr::ARRAY: + os << dfly::RespSpan{*get(e.u)}; + break; + case RespExpr::ERROR: + os << "e(" << ToAbsl(get(e.u)) << ")"; + break; + } + + return os; +} + +ostream& operator<<(ostream& os, dfly::RespSpan ras) { + os << "["; + if (!ras.empty()) { + for (size_t i = 0; i < ras.size() - 1; ++i) { + os << ras[i] << ","; + } + os << ras.back(); + } + os << "]"; + + return os; +} + +} // namespace std \ No newline at end of file diff --git a/server/resp_expr.h b/server/resp_expr.h new file mode 100644 index 0000000..afc95ca --- /dev/null +++ b/server/resp_expr.h @@ -0,0 +1,52 @@ +// Copyright 2021, Beeri 15. All rights reserved. +// Author: Roman Gershman (romange@gmail.com) +// + +#pragma once + +#include +#include + +#include +#include + +namespace dfly { + +class RespExpr { + public: + using Buffer = absl::Span; + + enum Type : uint8_t { STRING, ARRAY, INT64, NIL, NIL_ARRAY, ERROR }; + + using Vec = std::vector; + Type type; + bool has_support; // whether pointers in this item are supported by external storage. + + std::variant u; + + RespExpr(Type t = NIL) : type(t), has_support(false) { + } + + static Buffer buffer(std::string* s) { + return Buffer{reinterpret_cast(s->data()), s->size()}; + } + + Buffer GetBuf() const { return std::get(u); } + + static const char* TypeName(Type t); +}; + +using RespVec = RespExpr::Vec; +using RespSpan = absl::Span; + +inline std::string_view ToAbsl(const absl::Span& s) { + return std::string_view{reinterpret_cast(s.data()), s.size()}; +} + +} // namespace dfly + +namespace std { +ostream& operator<<(ostream& os, const dfly::RespExpr& e); +ostream& operator<<(ostream& os, dfly::RespSpan rspan); + +} // namespace std \ No newline at end of file diff --git a/server/test_utils.cc b/server/test_utils.cc new file mode 100644 index 0000000..91bc956 --- /dev/null +++ b/server/test_utils.cc @@ -0,0 +1,114 @@ +// Copyright 2021, Beeri 15. All rights reserved. +// Author: Roman Gershman (romange@gmail.com) +// + +#include "server/test_utils.h" + +#include + +#include "base/logging.h" +#include "util/uring/uring_pool.h" + +namespace dfly { + +using namespace testing; +using namespace util; +using namespace std; + +bool RespMatcher::MatchAndExplain(const RespExpr& e, MatchResultListener* listener) const { + if (e.type != type_) { + *listener << "\nWrong type: " << e.type; + return false; + } + + if (type_ == RespExpr::STRING || type_ == RespExpr::ERROR) { + RespExpr::Buffer ebuf = e.GetBuf(); + std::string_view actual{reinterpret_cast(ebuf.data()), ebuf.size()}; + + if (type_ == RespExpr::ERROR && !absl::StrContains(actual, exp_str_)) { + *listener << "Actual does not contain '" << exp_str_ << "'"; + return false; + } + if (type_ == RespExpr::STRING && exp_str_ != actual) { + *listener << "\nActual string: " << actual; + return false; + } + } else if (type_ == RespExpr::INT64) { + auto actual = get(e.u); + if (exp_int_ != actual) { + *listener << "\nActual : " << actual << " expected: " << exp_int_; + return false; + } + } else if (type_ == RespExpr::ARRAY) { + size_t len = get(e.u)->size(); + if (len != size_t(exp_int_)) { + *listener << "Actual length " << len << ", expected: " << exp_int_; + return false; + } + } + + return true; +} + +void RespMatcher::DescribeTo(std::ostream* os) const { + *os << "is "; + switch (type_) { + case RespExpr::STRING: + case RespExpr::ERROR: + *os << exp_str_; + break; + + case RespExpr::INT64: + *os << exp_str_; + break; + default: + *os << "TBD"; + break; + } +} + +void RespMatcher::DescribeNegationTo(std::ostream* os) const { + *os << "is not "; +} + +bool RespTypeMatcher::MatchAndExplain(const RespExpr& e, MatchResultListener* listener) const { + if (e.type != type_) { + *listener << "\nWrong type: " << RespExpr::TypeName(e.type); + return false; + } + + return true; +} + +void RespTypeMatcher::DescribeTo(std::ostream* os) const { + *os << "is " << RespExpr::TypeName(type_); +} + +void RespTypeMatcher::DescribeNegationTo(std::ostream* os) const { + *os << "is not " << RespExpr::TypeName(type_); +} + +void PrintTo(const RespExpr::Vec& vec, std::ostream* os) { + *os << "Vec: ["; + if (!vec.empty()) { + for (size_t i = 0; i < vec.size() - 1; ++i) { + *os << vec[i] << ","; + } + *os << vec.back(); + } + *os << "]\n"; +} + +vector ToIntArr(const RespVec& vec) { + vector res; + for (auto a : vec) { + int64_t val; + std::string_view s = ToAbsl(a.GetBuf()); + CHECK(absl::SimpleAtoi(s, &val)) << s; + res.push_back(val); + } + + return res; +} + +} // namespace dfly diff --git a/server/test_utils.h b/server/test_utils.h new file mode 100644 index 0000000..0fa7307 --- /dev/null +++ b/server/test_utils.h @@ -0,0 +1,86 @@ +// Copyright 2021, Beeri 15. All rights reserved. +// Author: Roman Gershman (romange@gmail.com) +// + +#pragma once + +#include + +#include "io/io.h" +#include "server/redis_parser.h" +#include "util/proactor_pool.h" + +namespace dfly { + +class RespMatcher { + public: + RespMatcher(std::string_view val, RespExpr::Type t = RespExpr::STRING) : type_(t), exp_str_(val) { + } + + RespMatcher(int64_t val, RespExpr::Type t = RespExpr::INT64) + : type_(t), exp_int_(val) { + } + + using is_gtest_matcher = void; + + bool MatchAndExplain(const RespExpr& e, testing::MatchResultListener*) const; + + void DescribeTo(std::ostream* os) const; + + void DescribeNegationTo(std::ostream* os) const; + + private: + RespExpr::Type type_; + + std::string exp_str_; + int64_t exp_int_; +}; + +class RespTypeMatcher { + public: + RespTypeMatcher(RespExpr::Type type) : type_(type) { + } + + using is_gtest_matcher = void; + + bool MatchAndExplain(const RespExpr& e, testing::MatchResultListener*) const; + + void DescribeTo(std::ostream* os) const; + + void DescribeNegationTo(std::ostream* os) const; + + private: + RespExpr::Type type_; +}; + +inline ::testing::PolymorphicMatcher StrArg(std::string_view str) { + return ::testing::MakePolymorphicMatcher(RespMatcher(str)); +} + +inline ::testing::PolymorphicMatcher ErrArg(std::string_view str) { + return ::testing::MakePolymorphicMatcher(RespMatcher(str, RespExpr::ERROR)); +} + +inline ::testing::PolymorphicMatcher IntArg(int64_t ival) { + return ::testing::MakePolymorphicMatcher(RespMatcher(ival)); +} + +inline ::testing::PolymorphicMatcher ArrLen(size_t len) { + return ::testing::MakePolymorphicMatcher(RespMatcher(len, RespExpr::ARRAY)); +} + +inline ::testing::PolymorphicMatcher ArgType(RespExpr::Type t) { + return ::testing::MakePolymorphicMatcher(RespTypeMatcher(t)); +} + +inline bool operator==(const RespExpr& left, const char* s) { + return left.type == RespExpr::STRING && ToAbsl(left.GetBuf()) == s; +} + +void PrintTo(const RespExpr::Vec& vec, std::ostream* os); + +MATCHER_P(RespEq, val, "") { + return ::testing::ExplainMatchResult(::testing::ElementsAre(StrArg(val)), arg, result_listener); +} + +} // namespace dfly