Add support for incr/decr/quit memcache commands

This commit is contained in:
Roman Gershman 2022-02-23 23:54:56 +02:00
parent fcb58efe15
commit a93940913b
9 changed files with 142 additions and 79 deletions

View File

@ -132,10 +132,10 @@ a distributed log format.
- [x] prepend
- [ ] delete
- [ ] flush_all
- [ ] incr
- [ ] decr
- [x] incr
- [x] decr
- [ ] version
- [ ] quit
- [x] quit
API 2.0
- [ ] List Family

View File

@ -392,6 +392,8 @@ auto Connection::ParseMemcache() -> ParserStatus {
if (result == MemcacheParser::PARSE_ERROR) {
builder->SendError(""); // ERROR.
} else if (result == MemcacheParser::BAD_DELTA) {
builder->SendClientError("invalid numeric delta argument");
} else if (result != MemcacheParser::OK) {
builder->SendClientError("bad command line format");
}

View File

@ -56,7 +56,6 @@ class InterpreterReplier : public RedisReplyBuilder {
}
void SendError(std::string_view str) override;
void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) override;
void SendStored() override;
void SendSimpleString(std::string_view str) final;
@ -154,11 +153,6 @@ void InterpreterReplier::SendError(string_view str) {
explr_->OnError(str);
}
void InterpreterReplier::SendGetReply(string_view key, uint32_t flags, string_view value) {
DCHECK(array_len_.empty());
explr_->OnString(value);
}
void InterpreterReplier::SendStored() {
DCHECK(array_len_.empty());
SendSimpleString("OK");
@ -477,19 +471,28 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
ConnectionContext* cntx) {
absl::InlinedVector<MutableSlice, 8> args;
char cmd_name[16];
char set_opt[4] = {0};
char store_opt[32] = {0};
MCReplyBuilder* mc_builder = static_cast<MCReplyBuilder*>(cntx->reply_builder());
switch (cmd.type) {
case MemcacheParser::REPLACE:
strcpy(cmd_name, "SET");
strcpy(set_opt, "XX");
strcpy(store_opt, "XX");
break;
case MemcacheParser::SET:
strcpy(cmd_name, "SET");
break;
case MemcacheParser::ADD:
strcpy(cmd_name, "SET");
strcpy(set_opt, "NX");
strcpy(store_opt, "NX");
break;
case MemcacheParser::INCR:
strcpy(cmd_name, "INCRBY");
absl::numbers_internal::FastIntToBuffer(cmd.delta, store_opt);
break;
case MemcacheParser::DECR:
strcpy(cmd_name, "DECRBY");
absl::numbers_internal::FastIntToBuffer(cmd.delta, store_opt);
break;
case MemcacheParser::APPEND:
strcpy(cmd_name, "APPEND");
@ -500,33 +503,42 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
case MemcacheParser::GET:
strcpy(cmd_name, "MGET");
break;
case MemcacheParser::QUIT:
strcpy(cmd_name, "QUIT");
break;
case MemcacheParser::STATS:
server_family_.StatsMC(cmd.key, cntx);
return;
default:
mc_builder->SendClientError("bad command line format");
return;
}
args.emplace_back(cmd_name, strlen(cmd_name));
char* key = const_cast<char*>(cmd.key.data());
args.emplace_back(key, cmd.key.size());
if (!cmd.key.empty()) {
char* key = const_cast<char*>(cmd.key.data());
args.emplace_back(key, cmd.key.size());
}
if (MemcacheParser::IsStoreCmd(cmd.type)) {
char* v = const_cast<char*>(value.data());
args.emplace_back(v, value.size());
if (set_opt[0]) {
args.emplace_back(set_opt, strlen(set_opt));
if (store_opt[0]) {
args.emplace_back(store_opt, strlen(store_opt));
}
cntx->conn_state.memcache_flag = cmd.flags;
} else {
} else if (cmd.type < MemcacheParser::QUIT) { // read commands
for (auto s : cmd.keys_ext) {
char* key = const_cast<char*>(s.data());
args.emplace_back(key, s.size());
}
} else { // write commands.
if (store_opt[0]) {
args.emplace_back(store_opt, strlen(store_opt));
}
}
DispatchCommand(CmdArgList{args}, cntx);
@ -567,8 +579,11 @@ void Service::RegisterHttp(HttpListenerBase* listener) {
}
void Service::Quit(CmdArgList args, ConnectionContext* cntx) {
(*cntx)->SendOk();
(*cntx)->CloseConnection();
if (cntx->protocol() == Protocol::REDIS)
(*cntx)->SendOk();
SinkReplyBuilder* builder = static_cast<SinkReplyBuilder*>(cntx->reply_builder());
builder->CloseConnection();
}
void Service::Multi(CmdArgList args, ConnectionContext* cntx) {

View File

@ -1,49 +1,52 @@
// Copyright 2021, Roman Gershman. All rights reserved.
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#include "server/memcache_parser.h"
#include <absl/container/flat_hash_map.h>
#include <absl/strings/ascii.h>
#include <absl/strings/numbers.h>
#include "base/stl_util.h"
namespace dfly {
using namespace std;
using MP = MemcacheParser;
namespace {
pair<string_view, MemcacheParser::CmdType> cmd_map[] = {
{"set", MemcacheParser::SET}, {"add", MemcacheParser::ADD},
{"replace", MemcacheParser::REPLACE}, {"append", MemcacheParser::APPEND},
{"prepend", MemcacheParser::PREPEND}, {"cas", MemcacheParser::CAS},
{"get", MemcacheParser::GET}, {"gets", MemcacheParser::GETS},
{"gat", MemcacheParser::GAT}, {"gats", MemcacheParser::GATS},
{"stats", MemcacheParser::STATS},
};
MP::CmdType From(string_view token) {
static absl::flat_hash_map<string_view, MP::CmdType> cmd_map{
{"set", MP::SET}, {"add", MP::ADD}, {"replace", MP::REPLACE},
{"append", MP::APPEND}, {"prepend", MP::PREPEND}, {"cas", MP::CAS},
{"get", MP::GET}, {"gets", MP::GETS}, {"gat", MP::GAT},
{"gats", MP::GATS}, {"stats", MP::STATS}, {"incr", MP::INCR},
{"decr", MP::DECR}, {"delete", MP::DELETE}, {"flush_all", MP::FLUSHALL},
{"quit", MP::QUIT},
};
MemcacheParser::CmdType From(string_view token) {
for (const auto& k_v : cmd_map) {
if (token == k_v.first)
return k_v.second;
}
return MemcacheParser::INVALID;
auto it = cmd_map.find(token);
if (it == cmd_map.end())
return MP::INVALID;
return it->second;
}
MemcacheParser::Result ParseStore(const std::string_view* tokens, unsigned num_tokens,
MemcacheParser::Command* res) {
MP::Result ParseStore(const std::string_view* tokens, unsigned num_tokens, MP::Command* res) {
unsigned opt_pos = 3;
if (res->type == MemcacheParser::CAS) {
if (res->type == MP::CAS) {
if (num_tokens <= opt_pos)
return MemcacheParser::PARSE_ERROR;
return MP::PARSE_ERROR;
++opt_pos;
}
uint32_t flags;
if (!absl::SimpleAtoi(tokens[0], &flags) || !absl::SimpleAtoi(tokens[1], &res->expire_ts) ||
!absl::SimpleAtoi(tokens[2], &res->bytes_len))
return MemcacheParser::BAD_INT;
return MP::BAD_INT;
if (res->type == MemcacheParser::CAS && !absl::SimpleAtoi(tokens[3], &res->cas_unique)) {
return MemcacheParser::BAD_INT;
if (res->type == MP::CAS && !absl::SimpleAtoi(tokens[3], &res->cas_unique)) {
return MP::BAD_INT;
}
res->flags = flags;
@ -51,39 +54,54 @@ MemcacheParser::Result ParseStore(const std::string_view* tokens, unsigned num_t
if (tokens[opt_pos] == "noreply") {
res->no_reply = true;
} else {
return MemcacheParser::PARSE_ERROR;
return MP::PARSE_ERROR;
}
} else if (num_tokens > opt_pos + 1) {
return MemcacheParser::PARSE_ERROR;
return MP::PARSE_ERROR;
}
return MemcacheParser::OK;
return MP::OK;
}
MemcacheParser::Result ParseRetrieve(const std::string_view* tokens, unsigned num_tokens,
MemcacheParser::Command* res) {
MP::Result ParseValueless(const std::string_view* tokens, unsigned num_tokens, MP::Command* res) {
unsigned key_pos = 0;
if (res->type == MemcacheParser::GAT || res->type == MemcacheParser::GATS) {
if (res->type == MP::GAT || res->type == MP::GATS) {
if (!absl::SimpleAtoi(tokens[0], &res->expire_ts)) {
return MemcacheParser::BAD_INT;
return MP::BAD_INT;
}
++key_pos;
}
res->key = tokens[key_pos++];
if (res->type == MemcacheParser::STATS && key_pos < num_tokens)
return MemcacheParser::PARSE_ERROR;
if (key_pos < num_tokens && base::_in(res->type, {MP::STATS, MP::FLUSHALL}))
return MP::PARSE_ERROR; // we do not support additional arguments for now.
if (res->type == MP::INCR || res->type == MP::DECR) {
if (key_pos == num_tokens)
return MP::PARSE_ERROR;
if (!absl::SimpleAtoi(tokens[key_pos], &res->delta))
return MP::BAD_DELTA;
++key_pos;
}
while (key_pos < num_tokens) {
res->keys_ext.push_back(tokens[key_pos++]);
}
return MemcacheParser::OK;
if (res->type >= MP::DELETE) { // write commands
if (!res->keys_ext.empty() && res->keys_ext.back() == "noreply") {
res->no_reply = true;
res->keys_ext.pop_back();
}
}
return MP::OK;
}
} // namespace
auto MemcacheParser::Parse(string_view str, uint32_t* consumed, Command* cmd) -> Result {
auto MP::Parse(string_view str, uint32_t* consumed, Command* cmd) -> Result {
auto pos = str.find('\n');
*consumed = 0;
if (pos == string_view::npos) {
@ -137,7 +155,7 @@ auto MemcacheParser::Parse(string_view str, uint32_t* consumed, Command* cmd) ->
if (cmd->type <= CAS) { // Store command
if (num_tokens < 5 || tokens[1].size() > 250) {
return MemcacheParser::PARSE_ERROR;
return MP::PARSE_ERROR;
}
// memcpy(single_key_, tokens[0].data(), tokens[0].size()); // we copy the key
@ -147,11 +165,12 @@ auto MemcacheParser::Parse(string_view str, uint32_t* consumed, Command* cmd) ->
}
if (num_tokens == 1) {
if (cmd->type == MemcacheParser::STATS)
return MemcacheParser::OK;
if (base::_in(cmd->type, {MP::STATS, MP::FLUSHALL, MP::QUIT}))
return MP::OK;
return MP::PARSE_ERROR;
}
return ParseRetrieve(tokens + 1, num_tokens - 1, cmd);
return ParseValueless(tokens + 1, num_tokens - 1, cmd);
};
} // namespace dfly

View File

@ -30,10 +30,13 @@ class MemcacheParser {
GATS = 13,
STATS = 14,
// Delete and INCR
QUIT = 20,
// The rest of write commands.
DELETE = 21,
INCR = 22,
DECR = 23,
FLUSHALL = 24,
};
// According to https://github.com/memcached/memcached/wiki/Commands#standard-protocol
@ -42,7 +45,11 @@ class MemcacheParser {
std::string_view key;
std::vector<std::string_view> keys_ext;
uint64_t cas_unique = 0;
union {
uint64_t cas_unique = 0; // for CAS COMMAND
uint64_t delta; // for DECR/INCR commands.
};
uint32_t expire_ts = 0; // relative time in seconds.
uint32_t bytes_len = 0;
uint32_t flags = 0;
@ -55,6 +62,7 @@ class MemcacheParser {
UNKNOWN_CMD,
BAD_INT,
PARSE_ERROR,
BAD_DELTA,
};
static bool IsStoreCmd(CmdType type) {

View File

@ -34,6 +34,31 @@ TEST_F(MCParserTest, Basic) {
EXPECT_EQ(1, cmd_.flags);
EXPECT_EQ(20, cmd_.expire_ts);
EXPECT_EQ(3, cmd_.bytes_len);
EXPECT_EQ(MemcacheParser::SET, cmd_.type);
st = parser_.Parse("quit\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::OK, st);
EXPECT_EQ(MemcacheParser::QUIT, cmd_.type);
}
TEST_F(MCParserTest, Incr) {
MemcacheParser::Result st = parser_.Parse("incr a\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::PARSE_ERROR, st);
st = parser_.Parse("incr a 1\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::OK, st);
EXPECT_EQ(MemcacheParser::INCR, cmd_.type);
EXPECT_EQ("a", cmd_.key);
EXPECT_EQ(1, cmd_.delta);
EXPECT_FALSE(cmd_.no_reply);
st = parser_.Parse("incr a -1\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::BAD_DELTA, st);
st = parser_.Parse("decr b 10 noreply\r\n", &consumed_, &cmd_);
EXPECT_EQ(MemcacheParser::OK, st);
EXPECT_EQ(MemcacheParser::DECR, cmd_.type);
EXPECT_EQ(10, cmd_.delta);
}
TEST_F(MCParserTest, Stats) {

View File

@ -88,10 +88,8 @@ void MCReplyBuilder::SendStored() {
SendDirect("STORED\r\n");
}
void MCReplyBuilder::SendGetReply(std::string_view key, uint32_t flags, std::string_view value) {
string first = absl::StrCat("VALUE ", key, " ", flags, " ", value.size(), "\r\n");
iovec v[] = {IoVec(first), IoVec(value), IoVec(kCRLF)};
Send(v, ABSL_ARRAYSIZE(v));
void MCReplyBuilder::SendLong(long val) {
SendDirect(absl::StrCat(val, kCRLF));
}
void MCReplyBuilder::SendMGetResponse(const OptResp* resp, uint32_t count) {
@ -140,10 +138,6 @@ void RedisReplyBuilder::SendError(string_view str) {
}
}
void RedisReplyBuilder::SendGetReply(std::string_view key, uint32_t flags, std::string_view value) {
SendBulkString(value);
}
void RedisReplyBuilder::SendStored() {
SendSimpleString("OK");
}

View File

@ -22,8 +22,6 @@ class ReplyBuilderInterface {
virtual std::error_code GetError() const = 0;
virtual void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) = 0;
struct ResponseValue {
std::string_view key;
std::string value;
@ -34,6 +32,7 @@ class ReplyBuilderInterface {
using OptResp = std::optional<ResponseValue>;
virtual void SendMGetResponse(const OptResp* resp, uint32_t count) = 0;
virtual void SendLong(long val) = 0;
virtual void SendSetSkipped() = 0;
};
@ -92,11 +91,11 @@ class MCReplyBuilder : public SinkReplyBuilder {
MCReplyBuilder(::io::Sink* stream);
void SendError(std::string_view str) final;
void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) final;
// void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) final;
void SendMGetResponse(const OptResp* resp, uint32_t count) final;
void SendStored() final;
void SendLong(long val) final;
void SendSetSkipped() final;
void SendClientError(std::string_view str);
@ -111,10 +110,10 @@ class RedisReplyBuilder : public SinkReplyBuilder {
}
void SendError(std::string_view str) override;
void SendGetReply(std::string_view key, uint32_t flags, std::string_view value) override;
void SendMGetResponse(const OptResp* resp, uint32_t count) override;
void SendStored() override;
void SendLong(long val) override;
void SendSetSkipped() override;
void SendError(OpStatus status);
@ -126,7 +125,6 @@ class RedisReplyBuilder : public SinkReplyBuilder {
virtual void SendStringArr(absl::Span<const std::string_view> arr);
virtual void SendNull();
virtual void SendLong(long val);
virtual void SendDouble(double val);
virtual void SendBulkString(std::string_view str);

View File

@ -162,7 +162,6 @@ void StringFamily::Get(CmdArgList args, ConnectionContext* cntx) {
get_qps.Inc();
std::string_view key = ArgS(args, 1);
uint32_t mc_flag = 0;
auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<string> {
OpResult<MainIterator> it_res = shard->db_slice().Find(t->db_index(), key, OBJ_STRING);
@ -180,7 +179,7 @@ void StringFamily::Get(CmdArgList args, ConnectionContext* cntx) {
if (result) {
DVLOG(1) << "GET " << trans->DebugId() << ": " << key << " " << result.value();
(*cntx)->SendGetReply(key, mc_flag, result.value());
(*cntx)->SendBulkString(*result);
} else {
switch (result.status()) {
case OpStatus::WRONG_TYPE:
@ -215,7 +214,7 @@ void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) {
}
if (prev_val) {
(*cntx)->SendGetReply(key, 0, *prev_val);
(*cntx)->SendBulkString(*prev_val);
return;
}
return (*cntx)->SendNull();
@ -270,15 +269,18 @@ void StringFamily::IncrByGeneric(std::string_view key, int64_t val, ConnectionCo
};
OpResult<int64_t> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
auto* builder = cntx->reply_builder();
DVLOG(2) << "IncrByGeneric " << key << "/" << result.value();
switch (result.status()) {
case OpStatus::OK:
return (*cntx)->SendLong(result.value());
return builder->SendLong(result.value());
case OpStatus::INVALID_VALUE:
return (*cntx)->SendError(kInvalidIntErr);
return builder->SendError(kInvalidIntErr);
case OpStatus::OUT_OF_RANGE:
return (*cntx)->SendError("increment or decrement would overflow");
return builder->SendError("increment or decrement would overflow");
case OpStatus::WRONG_TYPE:
return builder->SendError(kWrongTypeErr);
default:;
}
__builtin_unreachable();