From c8fe7ba28b2dbd2660c243bb83244adabbd7702d Mon Sep 17 00:00:00 2001
From: Roman Gershman <roman@dragonflydb.io>
Date: Tue, 28 Jun 2022 19:51:25 +0300
Subject: [PATCH] fix(server): Fix serialization logic when returning an empty
 array.

1. Fix SendStringArr logic. Consolidate both string_view and string variants to using the same code.
2. Tighten the test framework so that it will test that the parser consumed the whole response.

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
---
 src/facade/reply_builder.cc   | 53 +++++++++++++++++++++--------------
 src/facade/reply_builder.h    |  3 ++
 src/server/dragonfly_test.cc  |  4 +++
 src/server/hset_family.cc     | 13 ++++++---
 src/server/set_family_test.cc |  5 ++++
 src/server/stream_family.cc   |  1 +
 src/server/test_utils.cc      | 10 ++++---
 src/server/test_utils.h       |  1 +
 8 files changed, 61 insertions(+), 29 deletions(-)

diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc
index 3014691..08fc082 100644
--- a/src/facade/reply_builder.cc
+++ b/src/facade/reply_builder.cc
@@ -297,13 +297,12 @@ void RedisReplyBuilder::SendNullArray() {
 }
 
 void RedisReplyBuilder::SendStringArr(absl::Span<const std::string_view> arr) {
-  string res = absl::StrCat("*", arr.size(), kCRLF);
-
-  for (size_t i = 0; i < arr.size(); ++i) {
-    StrAppend(&res, "$", arr[i].size(), kCRLF);
-    res.append(arr[i]).append(kCRLF);
+  if (arr.empty()) {
+    SendRaw("*0\r\n");
+    return;
   }
-  SendRaw(res);
+
+  SendStringArr(arr.data(), arr.size());
 }
 
 // This implementation a bit complicated because it uses vectorized
@@ -312,44 +311,59 @@ void RedisReplyBuilder::SendStringArr(absl::Span<const std::string_view> arr) {
 // We limit the vector length to 256 and when it fills up we flush it to the socket and continue
 // iterating.
 void RedisReplyBuilder::SendStringArr(absl::Span<const string> arr) {
+  if (arr.empty()) {
+    SendRaw("*0\r\n");
+    return;
+  }
+  SendStringArr(arr.data(), arr.size());
+}
+
+void RedisReplyBuilder::StartArray(unsigned len) {
+  SendRaw(absl::StrCat("*", len, kCRLF));
+}
+
+void RedisReplyBuilder::SendStringArr(StrPtr str_ptr, uint32_t len) {
   // When vector length is too long, Send returns EMSGSIZE.
-  size_t vec_len = std::min<size_t>(256u, arr.size());
+  size_t vec_len = std::min<size_t>(256u, len);
 
   absl::FixedArray<iovec, 16> vec(vec_len * 2 + 2);
   absl::FixedArray<char, 64> meta((vec_len + 1) * 16);
   char* next = meta.data();
 
   *next++ = '*';
-  next = absl::numbers_internal::FastIntToBuffer(arr.size(), next);
+  next = absl::numbers_internal::FastIntToBuffer(len, next);
   *next++ = '\r';
   *next++ = '\n';
-  vec[0].iov_base = meta.data();
-  vec[0].iov_len = next - meta.data();
+  vec[0] = IoVec(string_view{meta.data(), size_t(next - meta.data())});
   char* start = next;
 
   unsigned vec_indx = 1;
-  for (unsigned i = 0; i < arr.size(); ++i) {
-    const auto& src = arr[i];
+  string_view src;
+  for (unsigned i = 0; i < len; ++i) {
+
+    if (holds_alternative<const string_view*>(str_ptr)) {
+      src = get<const string_view*>(str_ptr)[i];
+    } else {
+      src = get<const string*>(str_ptr)[i];
+    }
     *next++ = '$';
     next = absl::numbers_internal::FastIntToBuffer(src.size(), next);
     *next++ = '\r';
     *next++ = '\n';
-    vec[vec_indx].iov_base = start;
-    vec[vec_indx].iov_len = next - start;
+    vec[vec_indx] = IoVec(string_view{start, size_t(next - start)});
     DCHECK_GT(next - start, 0);
 
     start = next;
     ++vec_indx;
 
-    vec[vec_indx].iov_base = const_cast<char*>(src.data());
-    vec[vec_indx].iov_len = src.size();
+    vec[vec_indx] = IoVec(src);
 
     *next++ = '\r';
     *next++ = '\n';
     ++vec_indx;
 
     if (vec_indx + 1 >= vec.size()) {
-      if (i < arr.size() - 1 || vec_indx == vec.size()) {
+      if (i < len - 1 || vec_indx == vec.size()) {
         Send(vec.data(), vec_indx);
         if (ec_)
           return;
@@ -362,15 +376,12 @@ void RedisReplyBuilder::SendStringArr(absl::Span<const string> arr) {
       }
     }
   }
+
   vec[vec_indx].iov_base = start;
   vec[vec_indx].iov_len = 2;
   Send(vec.data(), vec_indx + 1);
 }
 
-void RedisReplyBuilder::StartArray(unsigned len) {
-  SendRaw(absl::StrCat("*", len, kCRLF));
-}
-
 void ReqSerializer::SendCommand(std::string_view str) {
   VLOG(1) << "SendCommand: " << str;
 
diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h
index 99a58a9..499a943 100644
--- a/src/facade/reply_builder.h
+++ b/src/facade/reply_builder.h
@@ -143,6 +143,9 @@ class RedisReplyBuilder : public SinkReplyBuilder {
   static char* FormatDouble(double val, char* dest, unsigned dest_len);
 
  private:
+
+  using StrPtr = std::variant<const std::string_view*, const std::string*>;
+  void SendStringArr(StrPtr str_ptr, uint32_t len);
 };
 
 class ReqSerializer {
diff --git a/src/server/dragonfly_test.cc b/src/server/dragonfly_test.cc
index 92d6548..b5f6773 100644
--- a/src/server/dragonfly_test.cc
+++ b/src/server/dragonfly_test.cc
@@ -478,6 +478,7 @@ TEST_F(DflyEngineTest, OOM) {
 }
 
 TEST_F(DflyEngineTest, PSubscribe) {
+  single_response_ = false;
   auto resp = pp_->at(1)->Await([&] { return Run({"psubscribe", "a*", "b*"}); });
   EXPECT_THAT(resp, ArrLen(3));
   resp = pp_->at(0)->Await([&] { return Run({"publish", "ab", "foo"}); });
@@ -498,6 +499,8 @@ TEST_F(DflyEngineTest, Unsubscribe) {
   resp = Run({"unsubscribe"});
   EXPECT_THAT(resp.GetVec(), ElementsAre("unsubscribe", ArgType(RespExpr::NIL), IntArg(0)));
 
+  single_response_ = false;
+
   Run({"subscribe", "a", "b"});
 
   resp = Run({"unsubscribe", "a"});
@@ -514,6 +517,7 @@ TEST_F(DflyEngineTest, PUnsubscribe) {
   resp = Run({"punsubscribe"});
   EXPECT_THAT(resp.GetVec(), ElementsAre("punsubscribe", ArgType(RespExpr::NIL), IntArg(0)));
 
+  single_response_ = false;
   Run({"psubscribe", "a*", "b*"});
 
   resp = Run({"punsubscribe", "a*"});
diff --git a/src/server/hset_family.cc b/src/server/hset_family.cc
index f40efaf..b3ed44f 100644
--- a/src/server/hset_family.cc
+++ b/src/server/hset_family.cc
@@ -693,26 +693,31 @@ OpResult<vector<string>> HSetFamily::OpGetAll(const OpArgs& op_args, string_view
   hashTypeIterator* hi = hashTypeInitIterator(hset);
 
   vector<string> res;
+  bool keyval = (mask == (FIELDS | VALUES));
+  size_t len = hashTypeLength(hset);
+  res.resize(keyval ? len * 2 : len);
+  unsigned index = 0;
+
   if (hset->encoding == OBJ_ENCODING_LISTPACK) {
     while (hashTypeNext(hi) != C_ERR) {
       if (mask & FIELDS) {
-        res.push_back(LpGetVal(hi->fptr));
+        res[index++] = LpGetVal(hi->fptr);
       }
 
       if (mask & VALUES) {
-        res.push_back(LpGetVal(hi->vptr));
+        res[index++] = LpGetVal(hi->vptr);
       }
     }
   } else {
     while (hashTypeNext(hi) != C_ERR) {
       if (mask & FIELDS) {
         sds key = (sds)dictGetKey(hi->de);
-        res.emplace_back(key, sdslen(key));
+        res[index++].assign(key, sdslen(key));
       }
 
       if (mask & VALUES) {
         sds val = (sds)dictGetVal(hi->de);
-        res.emplace_back(val, sdslen(val));
+        res[index++].assign(val, sdslen(val));
       }
     }
   }
diff --git a/src/server/set_family_test.cc b/src/server/set_family_test.cc
index 79b3a95..c70d73e 100644
--- a/src/server/set_family_test.cc
+++ b/src/server/set_family_test.cc
@@ -132,4 +132,9 @@ TEST_F(SetFamilyTest, SPop) {
   EXPECT_THAT(resp.GetVec(), IsSubsetOf({"a", "b", "c"}));
 }
 
+TEST_F(SetFamilyTest, Empty) {
+  auto resp = Run({"smembers", "x"});
+  ASSERT_THAT(resp, ArrLen(0));
+}
+
 }  // namespace dfly
diff --git a/src/server/stream_family.cc b/src/server/stream_family.cc
index fab9586..f8ca845 100644
--- a/src/server/stream_family.cc
+++ b/src/server/stream_family.cc
@@ -852,6 +852,7 @@ void StreamFamily::XRangeGeneric(CmdArgList args, bool is_rev, ConnectionContext
         (*cntx)->SendBulkString(k_v.second);
       }
     }
+    return;
   }
 
   if (result.status() == OpStatus::KEY_NOTFOUND) {
diff --git a/src/server/test_utils.cc b/src/server/test_utils.cc
index ba55e2b..26a7768 100644
--- a/src/server/test_utils.cc
+++ b/src/server/test_utils.cc
@@ -65,7 +65,7 @@ class BaseFamilyTest::TestConnWrapper {
 
   CmdArgVec Args(ArgSlice list);
 
-  RespVec ParseResponse();
+  RespVec ParseResponse(bool fully_consumed);
 
   // returns: type(pmessage), pattern, channel, message.
   facade::Connection::PubMessage GetPubMessage(size_t index) const;
@@ -176,7 +176,7 @@ RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) {
   unique_lock lk(mu_);
   last_cmd_dbg_info_ = context->last_command_debug;
 
-  RespVec vec = conn_wrapper->ParseResponse();
+  RespVec vec = conn_wrapper->ParseResponse(single_response_);
   if (vec.size() == 1)
     return vec.front();
   RespVec* new_vec = new RespVec(vec);
@@ -298,7 +298,7 @@ CmdArgVec BaseFamilyTest::TestConnWrapper::Args(ArgSlice list) {
   return res;
 }
 
-RespVec BaseFamilyTest::TestConnWrapper::ParseResponse() {
+RespVec BaseFamilyTest::TestConnWrapper::ParseResponse(bool fully_consumed) {
   tmp_str_vec_.emplace_back(new string{sink_.str()});
   auto& s = *tmp_str_vec_.back();
   auto buf = RespExpr::buffer(&s);
@@ -308,7 +308,9 @@ RespVec BaseFamilyTest::TestConnWrapper::ParseResponse() {
   RespVec res;
   RedisParser::Result st = parser_->Parse(buf, &consumed, &res);
   CHECK_EQ(RedisParser::OK, st);
-
+  if (fully_consumed) {
+    DCHECK_EQ(consumed, s.size()) << s;
+  }
   return res;
 }
 
diff --git a/src/server/test_utils.h b/src/server/test_utils.h
index fd11c9f..d10ba9f 100644
--- a/src/server/test_utils.h
+++ b/src/server/test_utils.h
@@ -87,6 +87,7 @@ class BaseFamilyTest : public ::testing::Test {
   ConnectionContext::DebugInfo last_cmd_dbg_info_;
   uint64_t expire_now_;
   std::vector<RespVec*> resp_vec_;
+  bool single_response_ = true;
 };
 
 }  // namespace dfly