From a93940913ba4676fe32f58f33da7b6b5ea300186 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Wed, 23 Feb 2022 23:54:56 +0200 Subject: [PATCH] Add support for incr/decr/quit memcache commands --- README.md | 6 +-- server/dragonfly_connection.cc | 2 + server/main_service.cc | 49 +++++++++++------- server/memcache_parser.cc | 93 ++++++++++++++++++++-------------- server/memcache_parser.h | 12 ++++- server/memcache_parser_test.cc | 25 +++++++++ server/reply_builder.cc | 10 +--- server/reply_builder.h | 10 ++-- server/string_family.cc | 14 ++--- 9 files changed, 142 insertions(+), 79 deletions(-) diff --git a/README.md b/README.md index e726642..30078dd 100644 --- a/README.md +++ b/README.md @@ -132,10 +132,10 @@ a distributed log format. - [x] prepend - [ ] delete - [ ] flush_all -- [ ] incr -- [ ] decr +- [x] incr +- [x] decr - [ ] version -- [ ] quit +- [x] quit API 2.0 - [ ] List Family diff --git a/server/dragonfly_connection.cc b/server/dragonfly_connection.cc index 03e9bbe..eb2ceed 100644 --- a/server/dragonfly_connection.cc +++ b/server/dragonfly_connection.cc @@ -392,6 +392,8 @@ auto Connection::ParseMemcache() -> ParserStatus { if (result == MemcacheParser::PARSE_ERROR) { builder->SendError(""); // ERROR. + } else if (result == MemcacheParser::BAD_DELTA) { + builder->SendClientError("invalid numeric delta argument"); } else if (result != MemcacheParser::OK) { builder->SendClientError("bad command line format"); } diff --git a/server/main_service.cc b/server/main_service.cc index 49fe473..7416999 100644 --- a/server/main_service.cc +++ b/server/main_service.cc @@ -56,7 +56,6 @@ class InterpreterReplier : public RedisReplyBuilder { } void SendError(std::string_view str) override; - void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) override; void SendStored() override; void SendSimpleString(std::string_view str) final; @@ -154,11 +153,6 @@ void InterpreterReplier::SendError(string_view str) { explr_->OnError(str); } -void InterpreterReplier::SendGetReply(string_view key, uint32_t flags, string_view value) { - DCHECK(array_len_.empty()); - explr_->OnString(value); -} - void InterpreterReplier::SendStored() { DCHECK(array_len_.empty()); SendSimpleString("OK"); @@ -477,19 +471,28 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va ConnectionContext* cntx) { absl::InlinedVector args; char cmd_name[16]; - char set_opt[4] = {0}; + char store_opt[32] = {0}; + MCReplyBuilder* mc_builder = static_cast(cntx->reply_builder()); switch (cmd.type) { case MemcacheParser::REPLACE: strcpy(cmd_name, "SET"); - strcpy(set_opt, "XX"); + strcpy(store_opt, "XX"); break; case MemcacheParser::SET: strcpy(cmd_name, "SET"); break; case MemcacheParser::ADD: strcpy(cmd_name, "SET"); - strcpy(set_opt, "NX"); + strcpy(store_opt, "NX"); + break; + case MemcacheParser::INCR: + strcpy(cmd_name, "INCRBY"); + absl::numbers_internal::FastIntToBuffer(cmd.delta, store_opt); + break; + case MemcacheParser::DECR: + strcpy(cmd_name, "DECRBY"); + absl::numbers_internal::FastIntToBuffer(cmd.delta, store_opt); break; case MemcacheParser::APPEND: strcpy(cmd_name, "APPEND"); @@ -500,33 +503,42 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va case MemcacheParser::GET: strcpy(cmd_name, "MGET"); break; + case MemcacheParser::QUIT: + strcpy(cmd_name, "QUIT"); + break; case MemcacheParser::STATS: server_family_.StatsMC(cmd.key, cntx); return; default: - mc_builder->SendClientError("bad command line format"); return; } args.emplace_back(cmd_name, strlen(cmd_name)); - char* key = const_cast(cmd.key.data()); - args.emplace_back(key, cmd.key.size()); + + if (!cmd.key.empty()) { + char* key = const_cast(cmd.key.data()); + args.emplace_back(key, cmd.key.size()); + } if (MemcacheParser::IsStoreCmd(cmd.type)) { char* v = const_cast(value.data()); args.emplace_back(v, value.size()); - if (set_opt[0]) { - args.emplace_back(set_opt, strlen(set_opt)); + if (store_opt[0]) { + args.emplace_back(store_opt, strlen(store_opt)); } cntx->conn_state.memcache_flag = cmd.flags; - } else { + } else if (cmd.type < MemcacheParser::QUIT) { // read commands for (auto s : cmd.keys_ext) { char* key = const_cast(s.data()); args.emplace_back(key, s.size()); } + } else { // write commands. + if (store_opt[0]) { + args.emplace_back(store_opt, strlen(store_opt)); + } } DispatchCommand(CmdArgList{args}, cntx); @@ -567,8 +579,11 @@ void Service::RegisterHttp(HttpListenerBase* listener) { } void Service::Quit(CmdArgList args, ConnectionContext* cntx) { - (*cntx)->SendOk(); - (*cntx)->CloseConnection(); + if (cntx->protocol() == Protocol::REDIS) + (*cntx)->SendOk(); + + SinkReplyBuilder* builder = static_cast(cntx->reply_builder()); + builder->CloseConnection(); } void Service::Multi(CmdArgList args, ConnectionContext* cntx) { diff --git a/server/memcache_parser.cc b/server/memcache_parser.cc index 94853e9..c2f4807 100644 --- a/server/memcache_parser.cc +++ b/server/memcache_parser.cc @@ -1,49 +1,52 @@ -// Copyright 2021, Roman Gershman. All rights reserved. +// Copyright 2022, Roman Gershman. All rights reserved. // See LICENSE for licensing terms. // #include "server/memcache_parser.h" +#include #include #include +#include "base/stl_util.h" + namespace dfly { using namespace std; +using MP = MemcacheParser; namespace { -pair cmd_map[] = { - {"set", MemcacheParser::SET}, {"add", MemcacheParser::ADD}, - {"replace", MemcacheParser::REPLACE}, {"append", MemcacheParser::APPEND}, - {"prepend", MemcacheParser::PREPEND}, {"cas", MemcacheParser::CAS}, - {"get", MemcacheParser::GET}, {"gets", MemcacheParser::GETS}, - {"gat", MemcacheParser::GAT}, {"gats", MemcacheParser::GATS}, - {"stats", MemcacheParser::STATS}, -}; +MP::CmdType From(string_view token) { + static absl::flat_hash_map cmd_map{ + {"set", MP::SET}, {"add", MP::ADD}, {"replace", MP::REPLACE}, + {"append", MP::APPEND}, {"prepend", MP::PREPEND}, {"cas", MP::CAS}, + {"get", MP::GET}, {"gets", MP::GETS}, {"gat", MP::GAT}, + {"gats", MP::GATS}, {"stats", MP::STATS}, {"incr", MP::INCR}, + {"decr", MP::DECR}, {"delete", MP::DELETE}, {"flush_all", MP::FLUSHALL}, + {"quit", MP::QUIT}, + }; -MemcacheParser::CmdType From(string_view token) { - for (const auto& k_v : cmd_map) { - if (token == k_v.first) - return k_v.second; - } - return MemcacheParser::INVALID; + auto it = cmd_map.find(token); + if (it == cmd_map.end()) + return MP::INVALID; + + return it->second; } -MemcacheParser::Result ParseStore(const std::string_view* tokens, unsigned num_tokens, - MemcacheParser::Command* res) { +MP::Result ParseStore(const std::string_view* tokens, unsigned num_tokens, MP::Command* res) { unsigned opt_pos = 3; - if (res->type == MemcacheParser::CAS) { + if (res->type == MP::CAS) { if (num_tokens <= opt_pos) - return MemcacheParser::PARSE_ERROR; + return MP::PARSE_ERROR; ++opt_pos; } uint32_t flags; if (!absl::SimpleAtoi(tokens[0], &flags) || !absl::SimpleAtoi(tokens[1], &res->expire_ts) || !absl::SimpleAtoi(tokens[2], &res->bytes_len)) - return MemcacheParser::BAD_INT; + return MP::BAD_INT; - if (res->type == MemcacheParser::CAS && !absl::SimpleAtoi(tokens[3], &res->cas_unique)) { - return MemcacheParser::BAD_INT; + if (res->type == MP::CAS && !absl::SimpleAtoi(tokens[3], &res->cas_unique)) { + return MP::BAD_INT; } res->flags = flags; @@ -51,39 +54,54 @@ MemcacheParser::Result ParseStore(const std::string_view* tokens, unsigned num_t if (tokens[opt_pos] == "noreply") { res->no_reply = true; } else { - return MemcacheParser::PARSE_ERROR; + return MP::PARSE_ERROR; } } else if (num_tokens > opt_pos + 1) { - return MemcacheParser::PARSE_ERROR; + return MP::PARSE_ERROR; } - return MemcacheParser::OK; + return MP::OK; } -MemcacheParser::Result ParseRetrieve(const std::string_view* tokens, unsigned num_tokens, - MemcacheParser::Command* res) { +MP::Result ParseValueless(const std::string_view* tokens, unsigned num_tokens, MP::Command* res) { unsigned key_pos = 0; - if (res->type == MemcacheParser::GAT || res->type == MemcacheParser::GATS) { + if (res->type == MP::GAT || res->type == MP::GATS) { if (!absl::SimpleAtoi(tokens[0], &res->expire_ts)) { - return MemcacheParser::BAD_INT; + return MP::BAD_INT; } ++key_pos; } res->key = tokens[key_pos++]; - if (res->type == MemcacheParser::STATS && key_pos < num_tokens) - return MemcacheParser::PARSE_ERROR; + if (key_pos < num_tokens && base::_in(res->type, {MP::STATS, MP::FLUSHALL})) + return MP::PARSE_ERROR; // we do not support additional arguments for now. + + if (res->type == MP::INCR || res->type == MP::DECR) { + if (key_pos == num_tokens) + return MP::PARSE_ERROR; + + if (!absl::SimpleAtoi(tokens[key_pos], &res->delta)) + return MP::BAD_DELTA; + ++key_pos; + } while (key_pos < num_tokens) { res->keys_ext.push_back(tokens[key_pos++]); } - return MemcacheParser::OK; + if (res->type >= MP::DELETE) { // write commands + if (!res->keys_ext.empty() && res->keys_ext.back() == "noreply") { + res->no_reply = true; + res->keys_ext.pop_back(); + } + } + + return MP::OK; } } // namespace -auto MemcacheParser::Parse(string_view str, uint32_t* consumed, Command* cmd) -> Result { +auto MP::Parse(string_view str, uint32_t* consumed, Command* cmd) -> Result { auto pos = str.find('\n'); *consumed = 0; if (pos == string_view::npos) { @@ -137,7 +155,7 @@ auto MemcacheParser::Parse(string_view str, uint32_t* consumed, Command* cmd) -> if (cmd->type <= CAS) { // Store command if (num_tokens < 5 || tokens[1].size() > 250) { - return MemcacheParser::PARSE_ERROR; + return MP::PARSE_ERROR; } // memcpy(single_key_, tokens[0].data(), tokens[0].size()); // we copy the key @@ -147,11 +165,12 @@ auto MemcacheParser::Parse(string_view str, uint32_t* consumed, Command* cmd) -> } if (num_tokens == 1) { - if (cmd->type == MemcacheParser::STATS) - return MemcacheParser::OK; + if (base::_in(cmd->type, {MP::STATS, MP::FLUSHALL, MP::QUIT})) + return MP::OK; + return MP::PARSE_ERROR; } - return ParseRetrieve(tokens + 1, num_tokens - 1, cmd); + return ParseValueless(tokens + 1, num_tokens - 1, cmd); }; } // namespace dfly \ No newline at end of file diff --git a/server/memcache_parser.h b/server/memcache_parser.h index f35b448..fe54103 100644 --- a/server/memcache_parser.h +++ b/server/memcache_parser.h @@ -30,10 +30,13 @@ class MemcacheParser { GATS = 13, STATS = 14, - // Delete and INCR + QUIT = 20, + + // The rest of write commands. DELETE = 21, INCR = 22, DECR = 23, + FLUSHALL = 24, }; // According to https://github.com/memcached/memcached/wiki/Commands#standard-protocol @@ -42,7 +45,11 @@ class MemcacheParser { std::string_view key; std::vector keys_ext; - uint64_t cas_unique = 0; + union { + uint64_t cas_unique = 0; // for CAS COMMAND + uint64_t delta; // for DECR/INCR commands. + }; + uint32_t expire_ts = 0; // relative time in seconds. uint32_t bytes_len = 0; uint32_t flags = 0; @@ -55,6 +62,7 @@ class MemcacheParser { UNKNOWN_CMD, BAD_INT, PARSE_ERROR, + BAD_DELTA, }; static bool IsStoreCmd(CmdType type) { diff --git a/server/memcache_parser_test.cc b/server/memcache_parser_test.cc index deaa33a..0614c57 100644 --- a/server/memcache_parser_test.cc +++ b/server/memcache_parser_test.cc @@ -34,6 +34,31 @@ TEST_F(MCParserTest, Basic) { EXPECT_EQ(1, cmd_.flags); EXPECT_EQ(20, cmd_.expire_ts); EXPECT_EQ(3, cmd_.bytes_len); + EXPECT_EQ(MemcacheParser::SET, cmd_.type); + + st = parser_.Parse("quit\r\n", &consumed_, &cmd_); + EXPECT_EQ(MemcacheParser::OK, st); + EXPECT_EQ(MemcacheParser::QUIT, cmd_.type); +} + +TEST_F(MCParserTest, Incr) { + MemcacheParser::Result st = parser_.Parse("incr a\r\n", &consumed_, &cmd_); + EXPECT_EQ(MemcacheParser::PARSE_ERROR, st); + + st = parser_.Parse("incr a 1\r\n", &consumed_, &cmd_); + EXPECT_EQ(MemcacheParser::OK, st); + EXPECT_EQ(MemcacheParser::INCR, cmd_.type); + EXPECT_EQ("a", cmd_.key); + EXPECT_EQ(1, cmd_.delta); + EXPECT_FALSE(cmd_.no_reply); + + st = parser_.Parse("incr a -1\r\n", &consumed_, &cmd_); + EXPECT_EQ(MemcacheParser::BAD_DELTA, st); + + st = parser_.Parse("decr b 10 noreply\r\n", &consumed_, &cmd_); + EXPECT_EQ(MemcacheParser::OK, st); + EXPECT_EQ(MemcacheParser::DECR, cmd_.type); + EXPECT_EQ(10, cmd_.delta); } TEST_F(MCParserTest, Stats) { diff --git a/server/reply_builder.cc b/server/reply_builder.cc index 5f8a3fe..093fa59 100644 --- a/server/reply_builder.cc +++ b/server/reply_builder.cc @@ -88,10 +88,8 @@ void MCReplyBuilder::SendStored() { SendDirect("STORED\r\n"); } -void MCReplyBuilder::SendGetReply(std::string_view key, uint32_t flags, std::string_view value) { - string first = absl::StrCat("VALUE ", key, " ", flags, " ", value.size(), "\r\n"); - iovec v[] = {IoVec(first), IoVec(value), IoVec(kCRLF)}; - Send(v, ABSL_ARRAYSIZE(v)); +void MCReplyBuilder::SendLong(long val) { + SendDirect(absl::StrCat(val, kCRLF)); } void MCReplyBuilder::SendMGetResponse(const OptResp* resp, uint32_t count) { @@ -140,10 +138,6 @@ void RedisReplyBuilder::SendError(string_view str) { } } -void RedisReplyBuilder::SendGetReply(std::string_view key, uint32_t flags, std::string_view value) { - SendBulkString(value); -} - void RedisReplyBuilder::SendStored() { SendSimpleString("OK"); } diff --git a/server/reply_builder.h b/server/reply_builder.h index 7cd32e0..146aeca 100644 --- a/server/reply_builder.h +++ b/server/reply_builder.h @@ -22,8 +22,6 @@ class ReplyBuilderInterface { virtual std::error_code GetError() const = 0; - virtual void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) = 0; - struct ResponseValue { std::string_view key; std::string value; @@ -34,6 +32,7 @@ class ReplyBuilderInterface { using OptResp = std::optional; virtual void SendMGetResponse(const OptResp* resp, uint32_t count) = 0; + virtual void SendLong(long val) = 0; virtual void SendSetSkipped() = 0; }; @@ -92,11 +91,11 @@ class MCReplyBuilder : public SinkReplyBuilder { MCReplyBuilder(::io::Sink* stream); void SendError(std::string_view str) final; - void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) final; + // void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) final; void SendMGetResponse(const OptResp* resp, uint32_t count) final; void SendStored() final; - + void SendLong(long val) final; void SendSetSkipped() final; void SendClientError(std::string_view str); @@ -111,10 +110,10 @@ class RedisReplyBuilder : public SinkReplyBuilder { } void SendError(std::string_view str) override; - void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) override; void SendMGetResponse(const OptResp* resp, uint32_t count) override; void SendStored() override; + void SendLong(long val) override; void SendSetSkipped() override; void SendError(OpStatus status); @@ -126,7 +125,6 @@ class RedisReplyBuilder : public SinkReplyBuilder { virtual void SendStringArr(absl::Span arr); virtual void SendNull(); - virtual void SendLong(long val); virtual void SendDouble(double val); virtual void SendBulkString(std::string_view str); diff --git a/server/string_family.cc b/server/string_family.cc index 5aa352d..44cafc6 100644 --- a/server/string_family.cc +++ b/server/string_family.cc @@ -162,7 +162,6 @@ void StringFamily::Get(CmdArgList args, ConnectionContext* cntx) { get_qps.Inc(); std::string_view key = ArgS(args, 1); - uint32_t mc_flag = 0; auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult { OpResult it_res = shard->db_slice().Find(t->db_index(), key, OBJ_STRING); @@ -180,7 +179,7 @@ void StringFamily::Get(CmdArgList args, ConnectionContext* cntx) { if (result) { DVLOG(1) << "GET " << trans->DebugId() << ": " << key << " " << result.value(); - (*cntx)->SendGetReply(key, mc_flag, result.value()); + (*cntx)->SendBulkString(*result); } else { switch (result.status()) { case OpStatus::WRONG_TYPE: @@ -215,7 +214,7 @@ void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) { } if (prev_val) { - (*cntx)->SendGetReply(key, 0, *prev_val); + (*cntx)->SendBulkString(*prev_val); return; } return (*cntx)->SendNull(); @@ -270,15 +269,18 @@ void StringFamily::IncrByGeneric(std::string_view key, int64_t val, ConnectionCo }; OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + auto* builder = cntx->reply_builder(); DVLOG(2) << "IncrByGeneric " << key << "/" << result.value(); switch (result.status()) { case OpStatus::OK: - return (*cntx)->SendLong(result.value()); + return builder->SendLong(result.value()); case OpStatus::INVALID_VALUE: - return (*cntx)->SendError(kInvalidIntErr); + return builder->SendError(kInvalidIntErr); case OpStatus::OUT_OF_RANGE: - return (*cntx)->SendError("increment or decrement would overflow"); + return builder->SendError("increment or decrement would overflow"); + case OpStatus::WRONG_TYPE: + return builder->SendError(kWrongTypeErr); default:; } __builtin_unreachable();