diff --git a/core/op_status.h b/core/op_status.h index c44c97c..6f3c70f 100644 --- a/core/op_status.h +++ b/core/op_status.h @@ -13,6 +13,8 @@ enum class OpStatus : uint16_t { KEY_EXISTS, KEY_NOTFOUND, SKIPPED, + INVALID_VALUE, + OUT_OF_RANGE, WRONG_TYPE, TIMED_OUT, }; diff --git a/server/string_family.cc b/server/string_family.cc index eaf3641..ffe8af6 100644 --- a/server/string_family.cc +++ b/server/string_family.cc @@ -30,6 +30,8 @@ DEFINE_VARZ(VarzQps, get_qps); } // namespace + + SetCmd::SetCmd(DbSlice* db_slice) : db_slice_(db_slice) { } @@ -89,7 +91,7 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) { std::string_view value = ArgS(args, 2); VLOG(2) << "Set " << key << " " << value; - SetCmd::SetParams sparams{cntx->db_index()}; // TODO: db_index. + SetCmd::SetParams sparams{cntx->db_index()}; int64_t int_arg; for (size_t i = 3; i < args.size(); ++i) { @@ -205,6 +207,61 @@ void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) { return cntx->SendNull(); } +void StringFamily::Incr(CmdArgList args, ConnectionContext* cntx) { + std::string_view key = ArgS(args, 1); + return IncrByGeneric(key, 1, cntx); +} + +void StringFamily::IncrBy(CmdArgList args, ConnectionContext* cntx) { + DCHECK_EQ(3u, args.size()); + + std::string_view key = ArgS(args, 1); + std::string_view sval = ArgS(args, 2); + int64_t val; + + if (!absl::SimpleAtoi(sval, &val)) { + return cntx->SendError(kInvalidIntErr); + } + return IncrByGeneric(key, val, cntx); +} + +void StringFamily::Decr(CmdArgList args, ConnectionContext* cntx) { + std::string_view key = ArgS(args, 1); + return IncrByGeneric(key, -1, cntx); +} + +void StringFamily::DecrBy(CmdArgList args, ConnectionContext* cntx) { + std::string_view key = ArgS(args, 1); + std::string_view sval = ArgS(args, 2); + int64_t val; + + if (!absl::SimpleAtoi(sval, &val)) { + return cntx->SendError(kInvalidIntErr); + } + return IncrByGeneric(key, -val, cntx); +} + +void StringFamily::IncrByGeneric(std::string_view key, int64_t val, ConnectionContext* cntx) { + auto cb = [&](Transaction* t, EngineShard* shard) { + OpResult res = OpIncrBy(OpArgs{shard, t->db_index()}, key, val); + return res; + }; + + OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + + DVLOG(2) << "IncrByGeneric " << key << "/" << result.value(); + switch (result.status()) { + case OpStatus::OK: + return cntx->SendLong(result.value()); + case OpStatus::INVALID_VALUE: + return cntx->SendError(kInvalidIntErr); + case OpStatus::OUT_OF_RANGE: + return cntx->SendError("increment or decrement would overflow"); + default:; + } + __builtin_unreachable(); +} + void StringFamily::MGet(CmdArgList args, ConnectionContext* cntx) { DCHECK_GT(args.size(), 1U); @@ -293,6 +350,41 @@ OpStatus StringFamily::OpMSet(const Transaction* t, EngineShard* es) { return OpStatus::OK; } +OpResult StringFamily::OpIncrBy(const OpArgs& op_args, std::string_view key, + int64_t incr) { + auto& db_slice = op_args.shard->db_slice(); + auto [it, expire_it] = db_slice.FindExt(op_args.db_ind, key); + + if (!IsValid(it)) { + CompactObj cobj; + cobj.SetInt(incr); + + db_slice.AddNew(op_args.db_ind, key, std::move(cobj), 0); + return incr; + } + + if (it->second.ObjType() != OBJ_STRING) { + return OpStatus::WRONG_TYPE; + } + + auto opt_prev = it->second.TryGetInt(); + if (!opt_prev) { + return OpStatus::INVALID_VALUE; + } + + long long prev = *opt_prev; + if ((incr < 0 && prev < 0 && incr < (LLONG_MIN - prev)) || + (incr > 0 && prev > 0 && incr > (LLONG_MAX - prev))) { + return OpStatus::OUT_OF_RANGE; + } + + int64_t new_val = prev + incr; + it->second.SetInt(new_val); + + return new_val; +} + + void StringFamily::Init(util::ProactorPool* pp) { set_qps.Init(pp); get_qps.Init(pp); @@ -307,6 +399,10 @@ void StringFamily::Shutdown() { void StringFamily::Register(CommandRegistry* registry) { *registry << CI{"SET", CO::WRITE | CO::DENYOOM, -3, 1, 1, 1}.HFUNC(Set) + << CI{"INCR", CO::WRITE | CO::DENYOOM | CO::FAST, 2, 1, 1, 1}.HFUNC(Incr) + << CI{"DECR", CO::WRITE | CO::DENYOOM | CO::FAST, 2, 1, 1, 1}.HFUNC(Decr) + << CI{"INCRBY", CO::WRITE | CO::DENYOOM | CO::FAST, 3, 1, 1, 1}.HFUNC(IncrBy) + << CI{"DECRBY", CO::WRITE | CO::DENYOOM | CO::FAST, 3, 1, 1, 1}.HFUNC(DecrBy) << CI{"GET", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(Get) << CI{"GETSET", CO::WRITE | CO::DENYOOM | CO::FAST, 3, 1, 1, 1}.HFUNC(GetSet) << CI{"MGET", CO::READONLY | CO::FAST, -2, 1, -1, 1}.HFUNC(MGet) diff --git a/server/string_family.h b/server/string_family.h index eaba08c..c6a2f6f 100644 --- a/server/string_family.h +++ b/server/string_family.h @@ -55,12 +55,19 @@ class StringFamily { static void GetSet(CmdArgList args, ConnectionContext* cntx); static void MGet(CmdArgList args, ConnectionContext* cntx); static void MSet(CmdArgList args, ConnectionContext* cntx); + static void Incr(CmdArgList args, ConnectionContext* cntx); + static void IncrBy(CmdArgList args, ConnectionContext* cntx); + static void Decr(CmdArgList args, ConnectionContext* cntx); + static void DecrBy(CmdArgList args, ConnectionContext* cntx); + + static void IncrByGeneric(std::string_view key, int64_t val, ConnectionContext* cntx); using MGetResponse = std::vector>; static MGetResponse OpMGet(const Transaction* t, EngineShard* shard); static OpStatus OpMSet(const Transaction* t, EngineShard* es); + static OpResult OpIncrBy(const OpArgs& op_args, std::string_view key, int64_t val); }; } // namespace dfly diff --git a/server/string_family_test.cc b/server/string_family_test.cc index 13d8e4d..b591aea 100644 --- a/server/string_family_test.cc +++ b/server/string_family_test.cc @@ -36,6 +36,22 @@ TEST_F(StringFamilyTest, SetGet) { EXPECT_THAT(Run({"get", "key"}), RespEq("2")); } +TEST_F(StringFamilyTest, Incr) { + ASSERT_THAT(Run({"set", "key", "0"}), RespEq("OK")); + ASSERT_THAT(Run({"incr", "key"}), ElementsAre(IntArg(1))); + + ASSERT_THAT(Run({"set", "key1", "123456789"}), RespEq("OK")); + ASSERT_THAT(Run({"incrby", "key1", "0"}), ElementsAre(IntArg(123456789))); + + ASSERT_THAT(Run({"set", "key1", "-123456789"}), RespEq("OK")); + ASSERT_THAT(Run({"incrby", "key1", "0"}), ElementsAre(IntArg(-123456789))); + + ASSERT_THAT(Run({"set", "key1", " -123 "}), RespEq("OK")); + ASSERT_THAT(Run({"incrby", "key1", "1"}), ElementsAre(ErrArg("ERR value is not an integer"))); + + ASSERT_THAT(Run({"incrby", "ne", "0"}), ElementsAre(IntArg(0))); +} + TEST_F(StringFamilyTest, Expire) { constexpr uint64_t kNow = 232279092000; @@ -50,6 +66,10 @@ TEST_F(StringFamilyTest, Expire) { EXPECT_THAT(Run({"get", "key"}), ElementsAre(ArgType(RespExpr::NIL))); ASSERT_THAT(Run({"set", "i", "1", "PX", "10"}), RespEq("OK")); + ASSERT_THAT(Run({"incr", "i"}), ElementsAre(IntArg(2))); + + UpdateTime(kNow + 30); + ASSERT_THAT(Run({"incr", "i"}), ElementsAre(IntArg(1))); } TEST_F(StringFamilyTest, Set) { @@ -64,6 +84,12 @@ TEST_F(StringFamilyTest, Set) { resp = Run({"set", "foo", "bar", "xx"}); ASSERT_THAT(resp, RespEq("OK")); + resp = Run({"set", "foo", "bar", "ex", "abc"}); + ASSERT_THAT(resp, ElementsAre(ErrArg(kInvalidIntErr))); + + resp = Run({"set", "foo", "bar", "ex", "-1"}); + ASSERT_THAT(resp, ElementsAre(ErrArg("invalid expire time in set"))); + resp = Run({"set", "foo", "bar", "ex", "1"}); ASSERT_THAT(resp, RespEq("OK")); } @@ -94,6 +120,61 @@ TEST_F(StringFamilyTest, MGetSet) { set_fb.join(); } +TEST_F(StringFamilyTest, MSetGet) { + Run({"mset", "x", "0", "y", "0", "a", "0", "b", "0"}); + ASSERT_EQ(2, GetDebugInfo().shards_count); + + Run({"mset", "x", "0", "y", "0"}); + ASSERT_EQ(1, GetDebugInfo().shards_count); + + Run({"mset", "x", "1", "b", "5", "x", "0"}); + ASSERT_EQ(2, GetDebugInfo().shards_count); + + int64_t val = CheckedInt({"get", "x"}); + EXPECT_EQ(0, val); + + val = CheckedInt({"get", "b"}); + EXPECT_EQ(5, val); + + auto mset_fb = pp_->at(0)->LaunchFiber([&] { + for (size_t i = 0; i < 1000; ++i) { + RespVec resp = Run({"mset", "x", StrCat(i), "b", StrCat(i)}); + ASSERT_THAT(resp, RespEq("OK")) << i; + } + }); + + // A problematic order when mset is not atomic: set x, get x, get b (old), set b + auto get_fb = pp_->at(2)->LaunchFiber([&] { + for (size_t i = 0; i < 1000; ++i) { + int64_t x = CheckedInt({"get", "x"}); + int64_t z = CheckedInt({"get", "b"}); + + ASSERT_LE(x, z) << "Inconsistency at " << i; + } + }); + + mset_fb.join(); + get_fb.join(); +} + + +TEST_F(StringFamilyTest, MSetDel) { + auto mset_fb = pp_->at(0)->LaunchFiber([&] { + for (size_t i = 0; i < 1000; ++i) { + Run({"mset", "x", "0", "z", "0"}); + } + }); + + auto del_fb = pp_->at(2)->LaunchFiber([&] { + for (size_t i = 0; i < 1000; ++i) { + CheckedInt({"del", "x", "z"}); + } + }); + + mset_fb.join(); + del_fb.join(); +} + TEST_F(StringFamilyTest, IntKey) { Run({"mset", "1", "1", "-1000", "-1000"}); auto resp = Run({"get", "1"}); @@ -127,4 +208,46 @@ TEST_F(StringFamilyTest, SingleShard) { mget_fb.join(); } +TEST_F(StringFamilyTest, MSetIncr) { + /* serialzable orders + init: x=z=0 + + mset x=z=1 + mset, incr x, incr z = 2, 2 + incr x, mset, incr z = 1, 2 + incr x, incr z, mset = 1, 1 +*/ + + /* unserializable scenario when mset is not atomic with respect to incr x + set x, incr x, incr z, set z = 2, 1 + */ + + Run({"mset", "a", "0", "b", "0", "c", "0"}); + ASSERT_EQ(2, GetDebugInfo("IO0").shards_count); + + auto mset_fb = pp_->at(0)->LaunchFiber([&] { + for (size_t i = 1; i < 1000; ++i) { + string base = StrCat(i * 900); + auto resp = Run({"mset", "b", base, "a", base, "c", base}); + ASSERT_THAT(resp, RespEq("OK")); + } + }); + + auto get_fb = pp_->at(1)->LaunchFiber([&] { + for (unsigned j = 0; j < 900; ++j) { + int64_t a = CheckedInt({"incr", "a"}); + int64_t b = CheckedInt({"incr", "b"}); + ASSERT_LE(a, b); + + int64_t c = CheckedInt({"incr", "c"}); + if (a > c) { + LOG(ERROR) << "Consistency error "; + } + ASSERT_LE(a, c); + } + }); + mset_fb.join(); + get_fb.join(); +} + } // namespace dfly