Add support for incr/decr/quit memcache commands

This commit is contained in:
Roman Gershman 2022-02-23 23:54:56 +02:00
parent fcb58efe15
commit a93940913b
9 changed files with 142 additions and 79 deletions

View File

@ -132,10 +132,10 @@ a distributed log format.
- [x] prepend - [x] prepend
- [ ] delete - [ ] delete
- [ ] flush_all - [ ] flush_all
- [ ] incr - [x] incr
- [ ] decr - [x] decr
- [ ] version - [ ] version
- [ ] quit - [x] quit
API 2.0 API 2.0
- [ ] List Family - [ ] List Family

View File

@ -392,6 +392,8 @@ auto Connection::ParseMemcache() -> ParserStatus {
if (result == MemcacheParser::PARSE_ERROR) { if (result == MemcacheParser::PARSE_ERROR) {
builder->SendError(""); // ERROR. builder->SendError(""); // ERROR.
} else if (result == MemcacheParser::BAD_DELTA) {
builder->SendClientError("invalid numeric delta argument");
} else if (result != MemcacheParser::OK) { } else if (result != MemcacheParser::OK) {
builder->SendClientError("bad command line format"); builder->SendClientError("bad command line format");
} }

View File

@ -56,7 +56,6 @@ class InterpreterReplier : public RedisReplyBuilder {
} }
void SendError(std::string_view str) override; void SendError(std::string_view str) override;
void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) override;
void SendStored() override; void SendStored() override;
void SendSimpleString(std::string_view str) final; void SendSimpleString(std::string_view str) final;
@ -154,11 +153,6 @@ void InterpreterReplier::SendError(string_view str) {
explr_->OnError(str); explr_->OnError(str);
} }
void InterpreterReplier::SendGetReply(string_view key, uint32_t flags, string_view value) {
DCHECK(array_len_.empty());
explr_->OnString(value);
}
void InterpreterReplier::SendStored() { void InterpreterReplier::SendStored() {
DCHECK(array_len_.empty()); DCHECK(array_len_.empty());
SendSimpleString("OK"); SendSimpleString("OK");
@ -477,19 +471,28 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
ConnectionContext* cntx) { ConnectionContext* cntx) {
absl::InlinedVector<MutableSlice, 8> args; absl::InlinedVector<MutableSlice, 8> args;
char cmd_name[16]; char cmd_name[16];
char set_opt[4] = {0}; char store_opt[32] = {0};
MCReplyBuilder* mc_builder = static_cast<MCReplyBuilder*>(cntx->reply_builder()); MCReplyBuilder* mc_builder = static_cast<MCReplyBuilder*>(cntx->reply_builder());
switch (cmd.type) { switch (cmd.type) {
case MemcacheParser::REPLACE: case MemcacheParser::REPLACE:
strcpy(cmd_name, "SET"); strcpy(cmd_name, "SET");
strcpy(set_opt, "XX"); strcpy(store_opt, "XX");
break; break;
case MemcacheParser::SET: case MemcacheParser::SET:
strcpy(cmd_name, "SET"); strcpy(cmd_name, "SET");
break; break;
case MemcacheParser::ADD: case MemcacheParser::ADD:
strcpy(cmd_name, "SET"); strcpy(cmd_name, "SET");
strcpy(set_opt, "NX"); strcpy(store_opt, "NX");
break;
case MemcacheParser::INCR:
strcpy(cmd_name, "INCRBY");
absl::numbers_internal::FastIntToBuffer(cmd.delta, store_opt);
break;
case MemcacheParser::DECR:
strcpy(cmd_name, "DECRBY");
absl::numbers_internal::FastIntToBuffer(cmd.delta, store_opt);
break; break;
case MemcacheParser::APPEND: case MemcacheParser::APPEND:
strcpy(cmd_name, "APPEND"); strcpy(cmd_name, "APPEND");
@ -500,33 +503,42 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
case MemcacheParser::GET: case MemcacheParser::GET:
strcpy(cmd_name, "MGET"); strcpy(cmd_name, "MGET");
break; break;
case MemcacheParser::QUIT:
strcpy(cmd_name, "QUIT");
break;
case MemcacheParser::STATS: case MemcacheParser::STATS:
server_family_.StatsMC(cmd.key, cntx); server_family_.StatsMC(cmd.key, cntx);
return; return;
default: default:
mc_builder->SendClientError("bad command line format"); mc_builder->SendClientError("bad command line format");
return; return;
} }
args.emplace_back(cmd_name, strlen(cmd_name)); args.emplace_back(cmd_name, strlen(cmd_name));
char* key = const_cast<char*>(cmd.key.data());
args.emplace_back(key, cmd.key.size()); if (!cmd.key.empty()) {
char* key = const_cast<char*>(cmd.key.data());
args.emplace_back(key, cmd.key.size());
}
if (MemcacheParser::IsStoreCmd(cmd.type)) { if (MemcacheParser::IsStoreCmd(cmd.type)) {
char* v = const_cast<char*>(value.data()); char* v = const_cast<char*>(value.data());
args.emplace_back(v, value.size()); args.emplace_back(v, value.size());
if (set_opt[0]) { if (store_opt[0]) {
args.emplace_back(set_opt, strlen(set_opt)); args.emplace_back(store_opt, strlen(store_opt));
} }
cntx->conn_state.memcache_flag = cmd.flags; cntx->conn_state.memcache_flag = cmd.flags;
} else { } else if (cmd.type < MemcacheParser::QUIT) { // read commands
for (auto s : cmd.keys_ext) { for (auto s : cmd.keys_ext) {
char* key = const_cast<char*>(s.data()); char* key = const_cast<char*>(s.data());
args.emplace_back(key, s.size()); args.emplace_back(key, s.size());
} }
} else { // write commands.
if (store_opt[0]) {
args.emplace_back(store_opt, strlen(store_opt));
}
} }
DispatchCommand(CmdArgList{args}, cntx); DispatchCommand(CmdArgList{args}, cntx);
@ -567,8 +579,11 @@ void Service::RegisterHttp(HttpListenerBase* listener) {
} }
void Service::Quit(CmdArgList args, ConnectionContext* cntx) { void Service::Quit(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendOk(); if (cntx->protocol() == Protocol::REDIS)
(*cntx)->CloseConnection(); (*cntx)->SendOk();
SinkReplyBuilder* builder = static_cast<SinkReplyBuilder*>(cntx->reply_builder());
builder->CloseConnection();
} }
void Service::Multi(CmdArgList args, ConnectionContext* cntx) { void Service::Multi(CmdArgList args, ConnectionContext* cntx) {

View File

@ -1,49 +1,52 @@
// Copyright 2021, Roman Gershman. All rights reserved. // Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms. // See LICENSE for licensing terms.
// //
#include "server/memcache_parser.h" #include "server/memcache_parser.h"
#include <absl/container/flat_hash_map.h>
#include <absl/strings/ascii.h> #include <absl/strings/ascii.h>
#include <absl/strings/numbers.h> #include <absl/strings/numbers.h>
#include "base/stl_util.h"
namespace dfly { namespace dfly {
using namespace std; using namespace std;
using MP = MemcacheParser;
namespace { namespace {
pair<string_view, MemcacheParser::CmdType> cmd_map[] = { MP::CmdType From(string_view token) {
{"set", MemcacheParser::SET}, {"add", MemcacheParser::ADD}, static absl::flat_hash_map<string_view, MP::CmdType> cmd_map{
{"replace", MemcacheParser::REPLACE}, {"append", MemcacheParser::APPEND}, {"set", MP::SET}, {"add", MP::ADD}, {"replace", MP::REPLACE},
{"prepend", MemcacheParser::PREPEND}, {"cas", MemcacheParser::CAS}, {"append", MP::APPEND}, {"prepend", MP::PREPEND}, {"cas", MP::CAS},
{"get", MemcacheParser::GET}, {"gets", MemcacheParser::GETS}, {"get", MP::GET}, {"gets", MP::GETS}, {"gat", MP::GAT},
{"gat", MemcacheParser::GAT}, {"gats", MemcacheParser::GATS}, {"gats", MP::GATS}, {"stats", MP::STATS}, {"incr", MP::INCR},
{"stats", MemcacheParser::STATS}, {"decr", MP::DECR}, {"delete", MP::DELETE}, {"flush_all", MP::FLUSHALL},
}; {"quit", MP::QUIT},
};
MemcacheParser::CmdType From(string_view token) { auto it = cmd_map.find(token);
for (const auto& k_v : cmd_map) { if (it == cmd_map.end())
if (token == k_v.first) return MP::INVALID;
return k_v.second;
} return it->second;
return MemcacheParser::INVALID;
} }
MemcacheParser::Result ParseStore(const std::string_view* tokens, unsigned num_tokens, MP::Result ParseStore(const std::string_view* tokens, unsigned num_tokens, MP::Command* res) {
MemcacheParser::Command* res) {
unsigned opt_pos = 3; unsigned opt_pos = 3;
if (res->type == MemcacheParser::CAS) { if (res->type == MP::CAS) {
if (num_tokens <= opt_pos) if (num_tokens <= opt_pos)
return MemcacheParser::PARSE_ERROR; return MP::PARSE_ERROR;
++opt_pos; ++opt_pos;
} }
uint32_t flags; uint32_t flags;
if (!absl::SimpleAtoi(tokens[0], &flags) || !absl::SimpleAtoi(tokens[1], &res->expire_ts) || if (!absl::SimpleAtoi(tokens[0], &flags) || !absl::SimpleAtoi(tokens[1], &res->expire_ts) ||
!absl::SimpleAtoi(tokens[2], &res->bytes_len)) !absl::SimpleAtoi(tokens[2], &res->bytes_len))
return MemcacheParser::BAD_INT; return MP::BAD_INT;
if (res->type == MemcacheParser::CAS && !absl::SimpleAtoi(tokens[3], &res->cas_unique)) { if (res->type == MP::CAS && !absl::SimpleAtoi(tokens[3], &res->cas_unique)) {
return MemcacheParser::BAD_INT; return MP::BAD_INT;
} }
res->flags = flags; res->flags = flags;
@ -51,39 +54,54 @@ MemcacheParser::Result ParseStore(const std::string_view* tokens, unsigned num_t
if (tokens[opt_pos] == "noreply") { if (tokens[opt_pos] == "noreply") {
res->no_reply = true; res->no_reply = true;
} else { } else {
return MemcacheParser::PARSE_ERROR; return MP::PARSE_ERROR;
} }
} else if (num_tokens > opt_pos + 1) { } else if (num_tokens > opt_pos + 1) {
return MemcacheParser::PARSE_ERROR; return MP::PARSE_ERROR;
} }
return MemcacheParser::OK; return MP::OK;
} }
MemcacheParser::Result ParseRetrieve(const std::string_view* tokens, unsigned num_tokens, MP::Result ParseValueless(const std::string_view* tokens, unsigned num_tokens, MP::Command* res) {
MemcacheParser::Command* res) {
unsigned key_pos = 0; unsigned key_pos = 0;
if (res->type == MemcacheParser::GAT || res->type == MemcacheParser::GATS) { if (res->type == MP::GAT || res->type == MP::GATS) {
if (!absl::SimpleAtoi(tokens[0], &res->expire_ts)) { if (!absl::SimpleAtoi(tokens[0], &res->expire_ts)) {
return MemcacheParser::BAD_INT; return MP::BAD_INT;
} }
++key_pos; ++key_pos;
} }
res->key = tokens[key_pos++]; res->key = tokens[key_pos++];
if (res->type == MemcacheParser::STATS && key_pos < num_tokens) if (key_pos < num_tokens && base::_in(res->type, {MP::STATS, MP::FLUSHALL}))
return MemcacheParser::PARSE_ERROR; return MP::PARSE_ERROR; // we do not support additional arguments for now.
if (res->type == MP::INCR || res->type == MP::DECR) {
if (key_pos == num_tokens)
return MP::PARSE_ERROR;
if (!absl::SimpleAtoi(tokens[key_pos], &res->delta))
return MP::BAD_DELTA;
++key_pos;
}
while (key_pos < num_tokens) { while (key_pos < num_tokens) {
res->keys_ext.push_back(tokens[key_pos++]); res->keys_ext.push_back(tokens[key_pos++]);
} }
return MemcacheParser::OK; if (res->type >= MP::DELETE) { // write commands
if (!res->keys_ext.empty() && res->keys_ext.back() == "noreply") {
res->no_reply = true;
res->keys_ext.pop_back();
}
}
return MP::OK;
} }
} // namespace } // namespace
auto MemcacheParser::Parse(string_view str, uint32_t* consumed, Command* cmd) -> Result { auto MP::Parse(string_view str, uint32_t* consumed, Command* cmd) -> Result {
auto pos = str.find('\n'); auto pos = str.find('\n');
*consumed = 0; *consumed = 0;
if (pos == string_view::npos) { if (pos == string_view::npos) {
@ -137,7 +155,7 @@ auto MemcacheParser::Parse(string_view str, uint32_t* consumed, Command* cmd) ->
if (cmd->type <= CAS) { // Store command if (cmd->type <= CAS) { // Store command
if (num_tokens < 5 || tokens[1].size() > 250) { if (num_tokens < 5 || tokens[1].size() > 250) {
return MemcacheParser::PARSE_ERROR; return MP::PARSE_ERROR;
} }
// memcpy(single_key_, tokens[0].data(), tokens[0].size()); // we copy the key // memcpy(single_key_, tokens[0].data(), tokens[0].size()); // we copy the key
@ -147,11 +165,12 @@ auto MemcacheParser::Parse(string_view str, uint32_t* consumed, Command* cmd) ->
} }
if (num_tokens == 1) { if (num_tokens == 1) {
if (cmd->type == MemcacheParser::STATS) if (base::_in(cmd->type, {MP::STATS, MP::FLUSHALL, MP::QUIT}))
return MemcacheParser::OK; return MP::OK;
return MP::PARSE_ERROR;
} }
return ParseRetrieve(tokens + 1, num_tokens - 1, cmd); return ParseValueless(tokens + 1, num_tokens - 1, cmd);
}; };
} // namespace dfly } // namespace dfly

View File

@ -30,10 +30,13 @@ class MemcacheParser {
GATS = 13, GATS = 13,
STATS = 14, STATS = 14,
// Delete and INCR QUIT = 20,
// The rest of write commands.
DELETE = 21, DELETE = 21,
INCR = 22, INCR = 22,
DECR = 23, DECR = 23,
FLUSHALL = 24,
}; };
// According to https://github.com/memcached/memcached/wiki/Commands#standard-protocol // According to https://github.com/memcached/memcached/wiki/Commands#standard-protocol
@ -42,7 +45,11 @@ class MemcacheParser {
std::string_view key; std::string_view key;
std::vector<std::string_view> keys_ext; std::vector<std::string_view> keys_ext;
uint64_t cas_unique = 0; union {
uint64_t cas_unique = 0; // for CAS COMMAND
uint64_t delta; // for DECR/INCR commands.
};
uint32_t expire_ts = 0; // relative time in seconds. uint32_t expire_ts = 0; // relative time in seconds.
uint32_t bytes_len = 0; uint32_t bytes_len = 0;
uint32_t flags = 0; uint32_t flags = 0;
@ -55,6 +62,7 @@ class MemcacheParser {
UNKNOWN_CMD, UNKNOWN_CMD,
BAD_INT, BAD_INT,
PARSE_ERROR, PARSE_ERROR,
BAD_DELTA,
}; };
static bool IsStoreCmd(CmdType type) { static bool IsStoreCmd(CmdType type) {

View File

@ -34,6 +34,31 @@ TEST_F(MCParserTest, Basic) {
EXPECT_EQ(1, cmd_.flags); EXPECT_EQ(1, cmd_.flags);
EXPECT_EQ(20, cmd_.expire_ts); EXPECT_EQ(20, cmd_.expire_ts);
EXPECT_EQ(3, cmd_.bytes_len); EXPECT_EQ(3, cmd_.bytes_len);
EXPECT_EQ(MemcacheParser::SET, cmd_.type);
st = parser_.Parse("quit\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::OK, st);
EXPECT_EQ(MemcacheParser::QUIT, cmd_.type);
}
TEST_F(MCParserTest, Incr) {
MemcacheParser::Result st = parser_.Parse("incr a\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::PARSE_ERROR, st);
st = parser_.Parse("incr a 1\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::OK, st);
EXPECT_EQ(MemcacheParser::INCR, cmd_.type);
EXPECT_EQ("a", cmd_.key);
EXPECT_EQ(1, cmd_.delta);
EXPECT_FALSE(cmd_.no_reply);
st = parser_.Parse("incr a -1\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::BAD_DELTA, st);
st = parser_.Parse("decr b 10 noreply\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::OK, st);
EXPECT_EQ(MemcacheParser::DECR, cmd_.type);
EXPECT_EQ(10, cmd_.delta);
} }
TEST_F(MCParserTest, Stats) { TEST_F(MCParserTest, Stats) {

View File

@ -88,10 +88,8 @@ void MCReplyBuilder::SendStored() {
SendDirect("STORED\r\n"); SendDirect("STORED\r\n");
} }
void MCReplyBuilder::SendGetReply(std::string_view key, uint32_t flags, std::string_view value) { void MCReplyBuilder::SendLong(long val) {
string first = absl::StrCat("VALUE ", key, " ", flags, " ", value.size(), "\r\n"); SendDirect(absl::StrCat(val, kCRLF));
iovec v[] = {IoVec(first), IoVec(value), IoVec(kCRLF)};
Send(v, ABSL_ARRAYSIZE(v));
} }
void MCReplyBuilder::SendMGetResponse(const OptResp* resp, uint32_t count) { void MCReplyBuilder::SendMGetResponse(const OptResp* resp, uint32_t count) {
@ -140,10 +138,6 @@ void RedisReplyBuilder::SendError(string_view str) {
} }
} }
void RedisReplyBuilder::SendGetReply(std::string_view key, uint32_t flags, std::string_view value) {
SendBulkString(value);
}
void RedisReplyBuilder::SendStored() { void RedisReplyBuilder::SendStored() {
SendSimpleString("OK"); SendSimpleString("OK");
} }

View File

@ -22,8 +22,6 @@ class ReplyBuilderInterface {
virtual std::error_code GetError() const = 0; virtual std::error_code GetError() const = 0;
virtual void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) = 0;
struct ResponseValue { struct ResponseValue {
std::string_view key; std::string_view key;
std::string value; std::string value;
@ -34,6 +32,7 @@ class ReplyBuilderInterface {
using OptResp = std::optional<ResponseValue>; using OptResp = std::optional<ResponseValue>;
virtual void SendMGetResponse(const OptResp* resp, uint32_t count) = 0; virtual void SendMGetResponse(const OptResp* resp, uint32_t count) = 0;
virtual void SendLong(long val) = 0;
virtual void SendSetSkipped() = 0; virtual void SendSetSkipped() = 0;
}; };
@ -92,11 +91,11 @@ class MCReplyBuilder : public SinkReplyBuilder {
MCReplyBuilder(::io::Sink* stream); MCReplyBuilder(::io::Sink* stream);
void SendError(std::string_view str) final; void SendError(std::string_view str) final;
void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) final; // void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) final;
void SendMGetResponse(const OptResp* resp, uint32_t count) final; void SendMGetResponse(const OptResp* resp, uint32_t count) final;
void SendStored() final; void SendStored() final;
void SendLong(long val) final;
void SendSetSkipped() final; void SendSetSkipped() final;
void SendClientError(std::string_view str); void SendClientError(std::string_view str);
@ -111,10 +110,10 @@ class RedisReplyBuilder : public SinkReplyBuilder {
} }
void SendError(std::string_view str) override; void SendError(std::string_view str) override;
void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) override;
void SendMGetResponse(const OptResp* resp, uint32_t count) override; void SendMGetResponse(const OptResp* resp, uint32_t count) override;
void SendStored() override; void SendStored() override;
void SendLong(long val) override;
void SendSetSkipped() override; void SendSetSkipped() override;
void SendError(OpStatus status); void SendError(OpStatus status);
@ -126,7 +125,6 @@ class RedisReplyBuilder : public SinkReplyBuilder {
virtual void SendStringArr(absl::Span<const std::string_view> arr); virtual void SendStringArr(absl::Span<const std::string_view> arr);
virtual void SendNull(); virtual void SendNull();
virtual void SendLong(long val);
virtual void SendDouble(double val); virtual void SendDouble(double val);
virtual void SendBulkString(std::string_view str); virtual void SendBulkString(std::string_view str);

View File

@ -162,7 +162,6 @@ void StringFamily::Get(CmdArgList args, ConnectionContext* cntx) {
get_qps.Inc(); get_qps.Inc();
std::string_view key = ArgS(args, 1); std::string_view key = ArgS(args, 1);
uint32_t mc_flag = 0;
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<string> { auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<string> {
OpResult<MainIterator> it_res = shard->db_slice().Find(t->db_index(), key, OBJ_STRING); OpResult<MainIterator> it_res = shard->db_slice().Find(t->db_index(), key, OBJ_STRING);
@ -180,7 +179,7 @@ void StringFamily::Get(CmdArgList args, ConnectionContext* cntx) {
if (result) { if (result) {
DVLOG(1) << "GET " << trans->DebugId() << ": " << key << " " << result.value(); DVLOG(1) << "GET " << trans->DebugId() << ": " << key << " " << result.value();
(*cntx)->SendGetReply(key, mc_flag, result.value()); (*cntx)->SendBulkString(*result);
} else { } else {
switch (result.status()) { switch (result.status()) {
case OpStatus::WRONG_TYPE: case OpStatus::WRONG_TYPE:
@ -215,7 +214,7 @@ void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) {
} }
if (prev_val) { if (prev_val) {
(*cntx)->SendGetReply(key, 0, *prev_val); (*cntx)->SendBulkString(*prev_val);
return; return;
} }
return (*cntx)->SendNull(); return (*cntx)->SendNull();
@ -270,15 +269,18 @@ void StringFamily::IncrByGeneric(std::string_view key, int64_t val, ConnectionCo
}; };
OpResult<int64_t> result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); OpResult<int64_t> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
auto* builder = cntx->reply_builder();
DVLOG(2) << "IncrByGeneric " << key << "/" << result.value(); DVLOG(2) << "IncrByGeneric " << key << "/" << result.value();
switch (result.status()) { switch (result.status()) {
case OpStatus::OK: case OpStatus::OK:
return (*cntx)->SendLong(result.value()); return builder->SendLong(result.value());
case OpStatus::INVALID_VALUE: case OpStatus::INVALID_VALUE:
return (*cntx)->SendError(kInvalidIntErr); return builder->SendError(kInvalidIntErr);
case OpStatus::OUT_OF_RANGE: case OpStatus::OUT_OF_RANGE:
return (*cntx)->SendError("increment or decrement would overflow"); return builder->SendError("increment or decrement would overflow");
case OpStatus::WRONG_TYPE:
return builder->SendError(kWrongTypeErr);
default:; default:;
} }
__builtin_unreachable(); __builtin_unreachable();