diff --git a/server/redis_parser.cc b/server/redis_parser.cc index fa6090a..a5da30a 100644 --- a/server/redis_parser.cc +++ b/server/redis_parser.cc @@ -96,6 +96,9 @@ void RedisParser::InitStart(uint8_t prefix_b, RespExpr::Vec* res) { switch (prefix_b) { case '$': + case ':': + case '+': + case '-': state_ = PARSE_ARG_S; parse_stack_.emplace_back(1, cached_expr_); // expression of length 1. break; @@ -241,12 +244,10 @@ auto RedisParser::ConsumeArrayLen(Buffer str) -> Result { LOG(ERROR) << "Unexpected result " << res; } - // Already parsed array expression somewhere. Server should accept only single-level expressions. - if (!parse_stack_.empty()) + if (parse_stack_.size() > 0 && server_mode_) return BAD_STRING; - // Similarly if our cached expr is not empty. - if (!cached_expr_->empty()) + if (parse_stack_.size() == 0 && !cached_expr_->empty()) return BAD_STRING; if (len <= 0) { @@ -263,7 +264,15 @@ auto RedisParser::ConsumeArrayLen(Buffer str) -> Result { } parse_stack_.emplace_back(len, cached_expr_); - DCHECK(cached_expr_->empty()); + if (!cached_expr_->empty()) { + DCHECK(!server_mode_); + cached_expr_->emplace_back(RespExpr::ARRAY); + stash_.emplace_back(new RespExpr::Vec()); + RespExpr::Vec* arr = stash_.back().get(); + arr->reserve(len); + cached_expr_->back().u = arr; + cached_expr_ = arr; + } state_ = PARSE_ARG_S; return OK; @@ -301,7 +310,47 @@ auto RedisParser::ParseArg(Buffer str) -> Result { return OK; } - return BAD_BULKLEN; + if (server_mode_) { + return BAD_BULKLEN; + } + + if (c == '*') { + return ConsumeArrayLen(str); + } + + char* s = reinterpret_cast(str.data() + 1); + char* eol = reinterpret_cast(memchr(s, '\n', str.size() - 1)); + + if (c == '+' || c == '-') { // Simple string or error. + DCHECK(!server_mode_); + if (!eol) { + return str.size() < 256 ? INPUT_PENDING : BAD_STRING; + } + if (eol[-1] != '\r') + return BAD_STRING; + + cached_expr_->emplace_back(c == '+' ? RespExpr::STRING : RespExpr::ERROR); + cached_expr_->back().u = Buffer{reinterpret_cast(s), size_t((eol - 1) - s)}; + } else if (c == ':') { + DCHECK(!server_mode_); + if (!eol) { + return str.size() < 32 ? INPUT_PENDING : BAD_INT; + } + int64_t ival; + std::string_view tok{s, size_t((eol - s) - 1)}; + + if (eol[-1] != '\r' || !absl::SimpleAtoi(tok, &ival)) + return BAD_INT; + + cached_expr_->emplace_back(RespExpr::INT64); + cached_expr_->back().u = ival; + } else { + return BAD_STRING; + } + + last_consumed_ = (eol - s) + 2; + state_ = FINISH_ARG_S; + return OK; } auto RedisParser::ConsumeBulk(Buffer str) -> Result { diff --git a/server/redis_parser.h b/server/redis_parser.h index 600dd7d..3a52df2 100644 --- a/server/redis_parser.h +++ b/server/redis_parser.h @@ -25,7 +25,7 @@ class RedisParser { }; using Buffer = RespExpr::Buffer; - explicit RedisParser() { + explicit RedisParser(bool server_mode = true) : server_mode_(server_mode) { } /** @@ -43,6 +43,14 @@ class RedisParser { Result Parse(Buffer str, uint32_t* consumed, RespVec* res); + void SetClientMode() { + server_mode_ = false; + } + + size_t parselen_hint() const { + return bulk_len_; + } + size_t stash_size() const { return stash_.size(); } const std::vector>& stash() const { return stash_;} @@ -86,6 +94,7 @@ class RedisParser { std::vector buf_stash_; RespVec* cached_expr_ = nullptr; bool is_broken_token_ = false; + bool server_mode_ = true; }; } // namespace dfly diff --git a/server/redis_parser_test.cc b/server/redis_parser_test.cc index d2e283d..a8e475e 100644 --- a/server/redis_parser_test.cc +++ b/server/redis_parser_test.cc @@ -4,6 +4,10 @@ #include "server/redis_parser.h" +extern "C" { + #include "redis/sds.h" +} + #include #include @@ -82,11 +86,49 @@ TEST_F(RedisParserTest, Inline) { EXPECT_EQ(2, consumed_); } +TEST_F(RedisParserTest, Sds) { + int argc; + sds* argv = sdssplitargs("\r\n",&argc); + EXPECT_EQ(0, argc); + sdsfreesplitres(argv,argc); + + argv = sdssplitargs("\026 \020 \200 \277 \r\n",&argc); + EXPECT_EQ(4, argc); + EXPECT_STREQ("\026", argv[0]); + sdsfreesplitres(argv,argc); + + argv = sdssplitargs(R"(abc "oops\n" )""\r\n",&argc); + EXPECT_EQ(2, argc); + EXPECT_STREQ("oops\n", argv[1]); + sdsfreesplitres(argv,argc); + + argv = sdssplitargs(R"( "abc\xf0" )" "\t'oops\n' \r\n",&argc); + ASSERT_EQ(2, argc); + EXPECT_STREQ("abc\xf0", argv[0]); + EXPECT_STREQ("oops\n", argv[1]); + sdsfreesplitres(argv,argc); +} + TEST_F(RedisParserTest, InlineEscaping) { LOG(ERROR) << "TBD: to be compliant with sdssplitargs"; // TODO: } TEST_F(RedisParserTest, Multi1) { + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("*1\r\n")); + EXPECT_EQ(4, consumed_); + EXPECT_EQ(0, parser_.parselen_hint()); + + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("$4\r\n")); + EXPECT_EQ(4, consumed_); + EXPECT_EQ(4, parser_.parselen_hint()); + + ASSERT_EQ(RedisParser::OK, Parse("PING\r\n")); + EXPECT_EQ(6, consumed_); + EXPECT_EQ(0, parser_.parselen_hint()); + EXPECT_THAT(args_, ElementsAre(StrArg("PING"))); +} + +TEST_F(RedisParserTest, Multi2) { ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("*1\r\n$")); EXPECT_EQ(4, consumed_); @@ -104,7 +146,7 @@ TEST_F(RedisParserTest, Multi1) { EXPECT_THAT(args_, ElementsAre("KEY", "VAL")); } -TEST_F(RedisParserTest, Multi2) { +TEST_F(RedisParserTest, Multi3) { const char kFirst[] = "*3\r\n$3\r\nSET\r\n$16\r\nkey:"; const char kSecond[] = "key:000002273458\r\n$3\r\nVXK"; ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(kFirst)); @@ -115,6 +157,29 @@ TEST_F(RedisParserTest, Multi2) { EXPECT_THAT(args_, ElementsAre("SET", "key:000002273458", "VXK")); } +TEST_F(RedisParserTest, ClientMode) { + parser_.SetClientMode(); + + ASSERT_EQ(RedisParser::OK, Parse(":-1\r\n")); + EXPECT_THAT(args_, ElementsAre(IntArg(-1))); + + ASSERT_EQ(RedisParser::OK, Parse("+OK\r\n")); + EXPECT_THAT(args_, RespEq("OK")); + + ASSERT_EQ(RedisParser::OK, Parse("-ERR foo bar\r\n")); + EXPECT_THAT(args_, ElementsAre(ErrArg("ERR foo"))); +} + +TEST_F(RedisParserTest, Hierarchy) { + parser_.SetClientMode(); + + const char* kThirdArg = "*2\r\n$3\r\n100\r\n$3\r\n200\r\n"; + string resp = absl::StrCat("*3\r\n$3\r\n900\r\n$3\r\n800\r\n", kThirdArg); + ASSERT_EQ(RedisParser::OK, Parse(resp)); + EXPECT_THAT(args_, ElementsAre(StrArg("900"), StrArg("800"), ArrArg(2))); + EXPECT_THAT(*get(args_[2].u), ElementsAre(StrArg("100"), StrArg("200"))); +} + TEST_F(RedisParserTest, InvalidMult1) { ASSERT_EQ(RedisParser::BAD_BULKLEN, Parse("*2\r\n$3\r\nFOO\r\nBAR\r\n")); } @@ -123,4 +188,28 @@ TEST_F(RedisParserTest, Empty) { ASSERT_EQ(RedisParser::OK, Parse("*2\r\n$0\r\n\r\n$0\r\n\r\n")); } +TEST_F(RedisParserTest, LargeBulk) { + std::string_view prefix("*1\r\n$1024\r\n"); + + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(prefix)); + ASSERT_EQ(prefix.size(), consumed_); + ASSERT_GE(parser_.parselen_hint(), 1024); + + string half(512, 'a'); + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(half)); + ASSERT_EQ(512, consumed_); + ASSERT_GE(parser_.parselen_hint(), 512); + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(half)); + ASSERT_EQ(512, consumed_); + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("\r")); + ASSERT_EQ(0, consumed_); + ASSERT_EQ(RedisParser::OK, Parse("\r\n")); + ASSERT_EQ(2, consumed_); + + string part1 = absl::StrCat(prefix, half); + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(part1)); + ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(half)); + ASSERT_EQ(RedisParser::OK, Parse("\r\n")); +} + } // namespace dfly diff --git a/server/test_utils.cc b/server/test_utils.cc index 5be31ef..6cab2f1 100644 --- a/server/test_utils.cc +++ b/server/test_utils.cc @@ -183,13 +183,6 @@ RespVec BaseFamilyTest::Run(initializer_list list) { last_cmd_dbg_info_ = context.last_command_debug; RespVec vec = conn->ParseResp(); - if (vec.size() == 1) { - auto buf = vec.front().GetBuf(); - if (!buf.empty() && buf[0] == '+') { - buf.remove_prefix(1); - std::get(vec.front().u) = buf; - } - } return vec; } @@ -230,7 +223,7 @@ RespVec BaseFamilyTest::TestConn::ParseResp() { auto buf = RespExpr::buffer(&s); uint32_t consumed = 0; - parser.reset(new RedisParser); + parser.reset(new RedisParser{false}); // Client mode. RespVec res; RedisParser::Result st = parser->Parse(buf, &consumed, &res); CHECK_EQ(RedisParser::OK, st);