diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json index d9e38de..ed591a1 100644 --- a/.vscode/c_cpp_properties.json +++ b/.vscode/c_cpp_properties.json @@ -8,7 +8,8 @@ "cStandard": "c17", "cppStandard": "c++17", "intelliSenseMode": "${default}", - "compileCommands": "${workspaceFolder}/build-dbg/compile_commands.json" + "compileCommands": "${workspaceFolder}/build-dbg/compile_commands.json", + "configurationProvider": "ms-vscode.cmake-tools" } ], "version": 4 diff --git a/src/server/blocking_controller_test.cc b/src/server/blocking_controller_test.cc index 3f56ec6..1fe04c4 100644 --- a/src/server/blocking_controller_test.cc +++ b/src/server/blocking_controller_test.cc @@ -27,7 +27,6 @@ class BlockingControllerTest : public Test { void TearDown() override; std::unique_ptr pp_; - std::unique_ptr ess_; boost::intrusive_ptr trans_; CommandId cid_; StringVec str_vec_; @@ -39,14 +38,10 @@ constexpr size_t kNumThreads = 3; void BlockingControllerTest::SetUp() { pp_.reset(new uring::UringPool(16, kNumThreads)); pp_->Run(); - ess_.reset(new EngineShardSet(pp_.get())); - ess_->Init(kNumThreads); + shard_set = new EngineShardSet(pp_.get()); + shard_set->Init(kNumThreads, false); - auto cb = [&](uint32_t index, ProactorBase* pb) { ess_->InitThreadLocal(pb, false); }; - - pp_->AwaitFiberOnAll(cb); - - trans_.reset(new Transaction{&cid_, ess_.get()}); + trans_.reset(new Transaction{&cid_}); str_vec_.assign({"blpop", "x", "z", "0"}); for (auto& s : str_vec_) { @@ -54,22 +49,23 @@ void BlockingControllerTest::SetUp() { } trans_->InitByArgs(0, {arg_vec_.data(), arg_vec_.size()}); - CHECK_EQ(0u, Shard("x", ess_->size())); - CHECK_EQ(2u, Shard("z", ess_->size())); + CHECK_EQ(0u, Shard("x", shard_set->size())); + CHECK_EQ(2u, Shard("z", shard_set->size())); const TestInfo* const test_info = UnitTest::GetInstance()->current_test_info(); LOG(INFO) << "Starting " << test_info->name(); } void BlockingControllerTest::TearDown() { - ess_->RunBlockingInParallel([](EngineShard*) { EngineShard::DestroyThreadLocal(); }); - ess_.reset(); + shard_set->Shutdown(); + delete shard_set; + pp_->Stop(); pp_.reset(); } TEST_F(BlockingControllerTest, Basic) { - ess_->Await(0, [&] { + shard_set->Await(0, [&] { BlockingController bc(EngineShard::tlocal()); bc.AddWatched(trans_.get()); EXPECT_EQ(1, bc.NumWatched(0)); @@ -87,13 +83,11 @@ TEST_F(BlockingControllerTest, Timeout) { bool res = trans_->WaitOnWatch(tp); EXPECT_FALSE(res); - unsigned num_watched = - ess_->Await(0, [&] { return EngineShard::tlocal()->blocking_controller()->NumWatched(0); }); + unsigned num_watched = shard_set->Await( + 0, [&] { return EngineShard::tlocal()->blocking_controller()->NumWatched(0); }); EXPECT_EQ(0, num_watched); trans_.reset(); - - } } // namespace dfly diff --git a/src/server/conn_context.h b/src/server/conn_context.h index ae0c3fd..e344388 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -78,7 +78,6 @@ class ConnectionContext : public facade::ConnectionContext { // TODO: to introduce proper accessors. Transaction* transaction = nullptr; const CommandId* cid = nullptr; - EngineShardSet* shard_set = nullptr; ConnectionState conn_state; DbIndex db_index() const { diff --git a/src/server/db_slice.h b/src/server/db_slice.h index 9d11f11..5d5565e 100644 --- a/src/server/db_slice.h +++ b/src/server/db_slice.h @@ -254,7 +254,8 @@ class DbSlice { time_t expire_base_[2]; // Used for expire logic, represents a real clock. uint64_t version_ = 1; // Used to version entries in the PrimeTable. - int64_t memory_budget_ = INT64_MAX; + ssize_t memory_budget_ = SSIZE_MAX; + mutable SliceEvents events_; // we may change this even for const operations. using LockTable = absl::flat_hash_map; diff --git a/src/server/debugcmd.cc b/src/server/debugcmd.cc index ffe9c75..f2e36a2 100644 --- a/src/server/debugcmd.cc +++ b/src/server/debugcmd.cc @@ -136,13 +136,12 @@ void DebugCmd::Reload(CmdArgList args) { } error_code ec; - EngineShardSet& ess = sf_.service().shard_set(); if (save) { string err_details; const CommandId* cid = sf_.service().FindCmd("SAVE"); CHECK_NOTNULL(cid); - intrusive_ptr trans(new Transaction{cid, &ess}); + intrusive_ptr trans(new Transaction{cid}); trans->InitByArgs(0, {}); VLOG(1) << "Performing save"; ec = sf_.DoSave(trans.get(), &err_details); @@ -157,9 +156,9 @@ void DebugCmd::Reload(CmdArgList args) { } void DebugCmd::Load(std::string_view filename) { - EngineShardSet& ess = sf_.service().shard_set(); + EngineShardSet& ess = *shard_set; const CommandId* cid = sf_.service().FindCmd("FLUSHALL"); - intrusive_ptr flush_trans(new Transaction{cid, &ess}); + intrusive_ptr flush_trans(new Transaction{cid}); flush_trans->InitByArgs(0, {}); VLOG(1) << "Performing flush"; error_code ec = sf_.DoFlush(flush_trans.get(), DbSlice::kDbAll); @@ -251,7 +250,7 @@ void DebugCmd::PopulateRangeFiber(uint64_t from, uint64_t len, std::string_view string key = absl::StrCat(prefix, ":"); size_t prefsize = key.size(); DbIndex db_indx = cntx_->db_index(); - EngineShardSet& ess = sf_.service().shard_set(); + EngineShardSet& ess = *shard_set; std::vector ps(ess.size(), PopulateBatch{db_indx}); SetCmd::SetParams params{db_indx}; @@ -281,7 +280,7 @@ void DebugCmd::PopulateRangeFiber(uint64_t from, uint64_t len, std::string_view } void DebugCmd::Inspect(string_view key) { - EngineShardSet& ess = sf_.service().shard_set(); + EngineShardSet& ess = *shard_set; ShardId sid = Shard(key, ess.size()); auto cb = [&]() -> facade::OpResult { diff --git a/src/server/dragonfly_test.cc b/src/server/dragonfly_test.cc index 2ffe4fc..9e86fb9 100644 --- a/src/server/dragonfly_test.cc +++ b/src/server/dragonfly_test.cc @@ -94,7 +94,7 @@ TEST_F(DflyEngineTest, Multi) { atomic_bool tx_empty = true; - ess_->RunBriefInParallel([&](EngineShard* shard) { + shard_set->RunBriefInParallel([&](EngineShard* shard) { if (!shard->txq()->Empty()) tx_empty.store(false); }); diff --git a/src/server/engine_shard_set.cc b/src/server/engine_shard_set.cc index 09582e1..7a17a38 100644 --- a/src/server/engine_shard_set.cc +++ b/src/server/engine_shard_set.cc @@ -34,6 +34,7 @@ vector cached_stats; // initialized in EngineShard thread_local EngineShard* EngineShard::shard_ = nullptr; constexpr size_t kQueueLen = 64; +EngineShardSet* shard_set = nullptr; EngineShard::Stats& EngineShard::Stats::operator+=(const EngineShard::Stats& o) { ooo_runs += o.ooo_runs; @@ -326,6 +327,11 @@ void EngineShard::CacheStats() { size_t used_mem = UsedMemory(); cached_stats[db_slice_.shard_id()].used_memory.store(used_mem, memory_order_relaxed); + ssize_t free_mem = max_memory_limit - used_mem_current.load(memory_order_relaxed); + if (free_mem < 0) + free_mem = 0; + + db_slice_.SetMemoryBudget(free_mem / shard_set->size()); } size_t EngineShard::UsedMemory() const { @@ -351,10 +357,20 @@ void EngineShard::AddBlocked(Transaction* trans) { */ -void EngineShardSet::Init(uint32_t sz) { +void EngineShardSet::Init(uint32_t sz, bool update_db_time) { CHECK_EQ(0u, size()); cached_stats.resize(sz); shard_queue_.resize(sz); + + pp_->AwaitFiberOnAll([&](uint32_t index, ProactorBase* pb) { + if (index < shard_queue_.size()) { + InitThreadLocal(pb, update_db_time); + } + }); +} + +void EngineShardSet::Shutdown() { + RunBlockingInParallel([](EngineShard*) { EngineShard::DestroyThreadLocal(); }); } void EngineShardSet::InitThreadLocal(ProactorBase* pb, bool update_db_time) { diff --git a/src/server/engine_shard_set.h b/src/server/engine_shard_set.h index 341d5a7..4578c0a 100644 --- a/src/server/engine_shard_set.h +++ b/src/server/engine_shard_set.h @@ -209,8 +209,8 @@ class EngineShardSet { return pp_; } - void Init(uint32_t size); - void InitThreadLocal(util::ProactorBase* pb, bool update_db_time); + void Init(uint32_t size, bool update_db_time); + void Shutdown(); static const std::vector& GetCachedStats(); @@ -236,6 +236,8 @@ class EngineShardSet { template void RunBlockingInParallel(U&& func); private: + void InitThreadLocal(util::ProactorBase* pb, bool update_db_time); + util::ProactorPool* pp_; std::vector shard_queue_; }; @@ -276,4 +278,7 @@ inline ShardId Shard(std::string_view v, ShardId shard_num) { return hash % shard_num; } + +extern EngineShardSet* shard_set; + } // namespace dfly diff --git a/src/server/generic_family.cc b/src/server/generic_family.cc index d9dced1..de02d03 100644 --- a/src/server/generic_family.cc +++ b/src/server/generic_family.cc @@ -228,7 +228,7 @@ uint64_t ScanGeneric(uint64_t cursor, const ScanOpts& scan_opts, StringVec* keys ConnectionContext* cntx) { ShardId sid = cursor % 1024; - EngineShardSet* ess = cntx->shard_set; + EngineShardSet* ess = shard_set; unsigned shard_count = ess->size(); // Dash table returns a cursor with its right byte empty. We will use it @@ -489,7 +489,7 @@ void GenericFamily::Select(CmdArgList args, ConnectionContext* cntx) { shard->db_slice().ActivateDb(index); return OpStatus::OK; }; - cntx->shard_set->RunBriefInParallel(std::move(cb)); + shard_set->RunBriefInParallel(std::move(cb)); return (*cntx)->SendOk(); } @@ -529,7 +529,7 @@ OpResult GenericFamily::RenameGeneric(CmdArgList args, bool skip_exist_des } transaction->Schedule(); - unsigned shard_count = transaction->shard_set()->size(); + unsigned shard_count = shard_set->size(); Renamer renamer{transaction->db_index(), Shard(key[0], shard_count)}; // Phase 1 -> Fetch keys from both shards. diff --git a/src/server/list_family.cc b/src/server/list_family.cc index a3937cf..f603296 100644 --- a/src/server/list_family.cc +++ b/src/server/list_family.cc @@ -127,7 +127,7 @@ OpResult FindFirst(Transaction* trans) { // Holds Find results: (iterator to a found key, and its index in the passed arguments). // See DbSlice::FindFirst for more details. // spans all the shards for now. - std::vector> find_res(trans->shard_set()->size()); + std::vector> find_res(shard_set->size()); fill(find_res.begin(), find_res.end(), OpStatus::KEY_NOTFOUND); auto cb = [&find_res](auto* t, EngineShard* shard) { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 728c6bd..4eea48a 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -296,8 +296,11 @@ bool EvalValidator(CmdArgList args, ConnectionContext* cntx) { } // namespace -Service::Service(ProactorPool* pp) : pp_(*pp), shard_set_(pp), server_family_(this) { +Service::Service(ProactorPool* pp) : pp_(*pp), server_family_(this) { CHECK(pp); + CHECK(shard_set == NULL); + + shard_set = new EngineShardSet(pp); // We support less than 1024 threads and we support less than 1024 shards. // For example, Scan uses 10 bits in cursor to encode shard id it currently traverses. @@ -308,6 +311,8 @@ Service::Service(ProactorPool* pp) : pp_(*pp), shard_set_(pp), server_family_(th } Service::~Service() { + delete shard_set; + shard_set = nullptr; } void Service::Init(util::AcceptServer* acceptor, util::ListenerInterface* main_interface, @@ -315,14 +320,10 @@ void Service::Init(util::AcceptServer* acceptor, util::ListenerInterface* main_i InitRedisTables(); uint32_t shard_num = pp_.size() > 1 ? pp_.size() - 1 : pp_.size(); - shard_set_.Init(shard_num); + shard_set->Init(shard_num, !opts.disable_time_update); pp_.AwaitFiberOnAll([&](uint32_t index, ProactorBase* pb) { ServerState::tlocal()->Init(); - - if (index < shard_count()) { - shard_set_.InitThreadLocal(pb, !opts.disable_time_update); - } }); request_latency_usec.Init(&pp_); @@ -352,7 +353,7 @@ void Service::Shutdown() { GenericFamily::Shutdown(); cmd_req.Shutdown(); - shard_set_.RunBlockingInParallel([&](EngineShard*) { EngineShard::DestroyThreadLocal(); }); + shard_set->Shutdown(); // wait for all the pending callbacks to stop. boost::this_fiber::sleep_for(10ms); @@ -366,7 +367,7 @@ static void MultiSetError(ConnectionContext* cntx) { void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) { CHECK(!args.empty()); - DCHECK_NE(0u, shard_set_.size()) << "Init was not called"; + DCHECK_NE(0u, shard_set->size()) << "Init was not called"; ToUpper(&args[0]); @@ -489,7 +490,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) DCHECK(dfly_cntx->transaction == nullptr); if (IsTransactional(cid)) { - dist_trans.reset(new Transaction{cid, &shard_set_}); + dist_trans.reset(new Transaction{cid}); OpStatus st = dist_trans->InitByArgs(dfly_cntx->conn_state.db_index, args); if (st != OpStatus::OK) return (*cntx)->SendError(st); @@ -619,7 +620,6 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va facade::ConnectionContext* Service::CreateContext(util::FiberSocketBase* peer, facade::Connection* owner) { ConnectionContext* res = new ConnectionContext{peer, owner}; - res->shard_set = &shard_set(); res->req_auth = IsPassProtected(); // a bit of a hack. I set up breaker callback here for the owner. @@ -651,7 +651,7 @@ bool Service::IsLocked(DbIndex db_index, std::string_view key) const { bool Service::IsShardSetLocked() const { std::atomic_uint res{0}; - shard_set_.RunBriefInParallel([&](EngineShard* shard) { + shard_set->RunBriefInParallel([&](EngineShard* shard) { bool unlocked = shard->shard_lock()->Check(IntentLock::SHARED); res.fetch_add(!unlocked, memory_order_relaxed); }); @@ -897,14 +897,14 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) { auto cb = [&] { return EngineShard::tlocal()->channel_slice().FetchSubscribers(channel); }; - vector subsriber_arr = shard_set_.Await(sid, std::move(cb)); + vector subsriber_arr = shard_set->Await(sid, std::move(cb)); atomic_uint32_t published{0}; if (!subsriber_arr.empty()) { sort(subsriber_arr.begin(), subsriber_arr.end(), [](const auto& left, const auto& right) { return left.thread_id < right.thread_id; }); - vector slices(shard_set_.pool()->size(), UINT_MAX); + vector slices(shard_set->pool()->size(), UINT_MAX); for (size_t i = 0; i < subsriber_arr.size(); ++i) { if (slices[subsriber_arr[i].thread_id] > i) { slices[subsriber_arr[i].thread_id] = i; @@ -932,7 +932,7 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) { } }; - shard_set_.pool()->Await(publish_cb); + shard_set->pool()->Await(publish_cb); bc.Wait(); // Wait for all the messages to be sent. } diff --git a/src/server/main_service.h b/src/server/main_service.h index 69f3e51..ec9efe7 100644 --- a/src/server/main_service.h +++ b/src/server/main_service.h @@ -48,17 +48,13 @@ class Service : public facade::ServiceInterface { facade::ConnectionStats* GetThreadLocalConnectionStats() final; uint32_t shard_count() const { - return shard_set_.size(); + return shard_set->size(); } // Used by tests. bool IsLocked(DbIndex db_index, std::string_view key) const; bool IsShardSetLocked() const; - EngineShardSet& shard_set() { - return shard_set_; - } - util::ProactorPool& proactor_pool() { return pp_; } @@ -102,7 +98,6 @@ class Service : public facade::ServiceInterface { util::ProactorPool& pp_; - EngineShardSet shard_set_; ServerFamily server_family_; CommandRegistry registry_; diff --git a/src/server/rdb_test.cc b/src/server/rdb_test.cc index ed86dfc..3b5ba5c 100644 --- a/src/server/rdb_test.cc +++ b/src/server/rdb_test.cc @@ -72,14 +72,14 @@ TEST_F(RdbTest, Crc) { TEST_F(RdbTest, LoadEmpty) { io::FileSource fs = GetSource("empty.rdb"); - RdbLoader loader(ess_, NULL); + RdbLoader loader(shard_set, NULL); auto ec = loader.Load(&fs); CHECK(!ec); } TEST_F(RdbTest, LoadSmall6) { io::FileSource fs = GetSource("redis6_small.rdb"); - RdbLoader loader(ess_, service_->script_mgr()); + RdbLoader loader(shard_set, service_->script_mgr()); auto ec = loader.Load(&fs); ASSERT_FALSE(ec) << ec.message(); diff --git a/src/server/script_mgr.cc b/src/server/script_mgr.cc index cf9304b..597fda7 100644 --- a/src/server/script_mgr.cc +++ b/src/server/script_mgr.cc @@ -16,7 +16,7 @@ namespace dfly { using namespace std; using namespace facade; -ScriptMgr::ScriptMgr(EngineShardSet* ess) : ess_(ess) { +ScriptMgr::ScriptMgr() { } void ScriptMgr::Run(CmdArgList args, ConnectionContext* cntx) { diff --git a/src/server/script_mgr.h b/src/server/script_mgr.h index 7094e69..f3dfc7f 100644 --- a/src/server/script_mgr.h +++ b/src/server/script_mgr.h @@ -18,17 +18,16 @@ class EngineShardSet; // This class has a state through the lifetime of a server because it manipulates scripts class ScriptMgr { public: - ScriptMgr(EngineShardSet* ess); + ScriptMgr(); void Run(CmdArgList args, ConnectionContext* cntx); - bool InsertFunction(std::string_view sha, std::string_view body); + bool InsertFunction(std::string_view sha, std::string_view body); // Returns body as null-terminated c-string. NULL if sha is not found. const char* Find(std::string_view sha) const; private: - EngineShardSet* ess_; using ScriptKey = std::array; absl::flat_hash_map> db_; mutable ::boost::fibers::mutex mu_; diff --git a/src/server/server_family.cc b/src/server/server_family.cc index cead09d..3186653 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -83,10 +83,10 @@ string UnknowSubCmd(string_view subcmd, string cmd) { } // namespace -ServerFamily::ServerFamily(Service* service) : service_(*service), ess_(service->shard_set()) { +ServerFamily::ServerFamily(Service* service) : service_(*service) { start_time_ = time(NULL); lsinfo_.save_time = start_time_; - script_mgr_.reset(new ScriptMgr(&service->shard_set())); + script_mgr_.reset(new ScriptMgr()); } ServerFamily::~ServerFamily() { @@ -97,7 +97,7 @@ void ServerFamily::Init(util::AcceptServer* acceptor, util::ListenerInterface* m acceptor_ = acceptor; main_listener_ = main_listener; - pb_task_ = ess_.pool()->GetNextProactor(); + pb_task_ = shard_set->pool()->GetNextProactor(); auto cache_cb = [] { uint64_t sum = 0; const auto& stats = EngineShardSet::GetCachedStats(); @@ -225,7 +225,7 @@ error_code ServerFamily::DoSave(Transaction* trans, string* err_details) { unique_ptr<::io::WriteFile> wf(*res); auto start = absl::Now(); - RdbSaver saver{&ess_, wf.get()}; + RdbSaver saver{shard_set, wf.get()}; ec = saver.SaveHeader(); if (!ec) { @@ -294,7 +294,7 @@ string ServerFamily::LastSaveFile() const { void ServerFamily::DbSize(CmdArgList args, ConnectionContext* cntx) { atomic_ulong num_keys{0}; - ess_.RunBriefInParallel( + shard_set->RunBriefInParallel( [&](EngineShard* shard) { auto db_size = shard->db_slice().DbSize(cntx->conn_state.db_index); num_keys.fetch_add(db_size, memory_order_relaxed); @@ -384,7 +384,7 @@ void ServerFamily::Config(CmdArgList args, ConnectionContext* cntx) { return (*cntx)->SendStringArr(res); } else if (sub_cmd == "RESETSTAT") { - ess_.pool()->Await([](auto*) { + shard_set->pool()->Await([](auto*) { auto* stats = ServerState::tl_connection_stats(); stats->cmd_count_map.clear(); stats->err_count_map.clear(); diff --git a/src/server/server_family.h b/src/server/server_family.h index a12a41c..ccd438c 100644 --- a/src/server/server_family.h +++ b/src/server/server_family.h @@ -71,7 +71,7 @@ class ServerFamily { private: uint32_t shard_count() const { - return ess_.size(); + return shard_set->size(); } void Auth(CmdArgList args, ConnectionContext* cntx); @@ -99,7 +99,6 @@ class ServerFamily { uint32_t task_10ms_ = 0; Service& service_; - EngineShardSet& ess_; util::AcceptServer* acceptor_ = nullptr; util::ListenerInterface* main_listener_ = nullptr; diff --git a/src/server/set_family.cc b/src/server/set_family.cc index 6d1345b..8d6de28 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -815,7 +815,7 @@ void SetFamily::SPop(CmdArgList args, ConnectionContext* cntx) { } void SetFamily::SDiff(CmdArgList args, ConnectionContext* cntx) { - ResultStringVec result_set(cntx->transaction->shard_set()->size(), OpStatus::SKIPPED); + ResultStringVec result_set(shard_set->size(), OpStatus::SKIPPED); std::string_view src_key = ArgS(args, 1); ShardId src_shard = Shard(src_key, result_set.size()); @@ -846,7 +846,7 @@ void SetFamily::SDiff(CmdArgList args, ConnectionContext* cntx) { } void SetFamily::SDiffStore(CmdArgList args, ConnectionContext* cntx) { - ResultStringVec result_set(cntx->transaction->shard_set()->size(), OpStatus::SKIPPED); + ResultStringVec result_set(shard_set->size(), OpStatus::SKIPPED); std::string_view dest_key = ArgS(args, 1); ShardId dest_shard = Shard(dest_key, result_set.size()); std::string_view src_key = ArgS(args, 2); @@ -917,7 +917,7 @@ void SetFamily::SMembers(CmdArgList args, ConnectionContext* cntx) { } void SetFamily::SInter(CmdArgList args, ConnectionContext* cntx) { - ResultStringVec result_set(cntx->transaction->shard_set()->size(), OpStatus::SKIPPED); + ResultStringVec result_set(shard_set->size(), OpStatus::SKIPPED); auto cb = [&](Transaction* t, EngineShard* shard) { result_set[shard->shard_id()] = OpInter(t, shard, false); @@ -939,7 +939,7 @@ void SetFamily::SInter(CmdArgList args, ConnectionContext* cntx) { } void SetFamily::SInterStore(CmdArgList args, ConnectionContext* cntx) { - ResultStringVec result_set(cntx->transaction->shard_set()->size(), OpStatus::SKIPPED); + ResultStringVec result_set(shard_set->size(), OpStatus::SKIPPED); std::string_view dest_key = ArgS(args, 1); ShardId dest_shard = Shard(dest_key, result_set.size()); atomic_uint32_t inter_shard_cnt{0}; @@ -979,7 +979,7 @@ void SetFamily::SInterStore(CmdArgList args, ConnectionContext* cntx) { } void SetFamily::SUnion(CmdArgList args, ConnectionContext* cntx) { - ResultStringVec result_set(cntx->transaction->shard_set()->size()); + ResultStringVec result_set(shard_set->size()); auto cb = [&](Transaction* t, EngineShard* shard) { ArgSlice largs = t->ShardArgsInShard(shard->shard_id()); @@ -1002,7 +1002,7 @@ void SetFamily::SUnion(CmdArgList args, ConnectionContext* cntx) { } void SetFamily::SUnionStore(CmdArgList args, ConnectionContext* cntx) { - ResultStringVec result_set(cntx->transaction->shard_set()->size(), OpStatus::SKIPPED); + ResultStringVec result_set(shard_set->size(), OpStatus::SKIPPED); std::string_view dest_key = ArgS(args, 1); ShardId dest_shard = Shard(dest_key, result_set.size()); diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 1468195..7563342 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -36,7 +36,6 @@ DEFINE_VARZ(VarzQps, get_qps); constexpr uint32_t kMaxStrLen = 1 << 28; constexpr uint32_t kMinTieredLen = TieredStorage::kMinBlobLen; - string GetString(EngineShard* shard, const PrimeValue& pv) { string res; if (pv.IsExternal()) { @@ -157,7 +156,11 @@ OpResult SetCmd::Set(const SetParams& params, std::string_view key, std::s PrimeValue tvalue{value}; tvalue.SetFlag(params.memcache_flags != 0); uint64_t at_ms = params.expire_after_ms ? params.expire_after_ms + db_slice_.Now() : 0; - it = db_slice_.AddNew(params.db_index, key, std::move(tvalue), at_ms); + try { + it = db_slice_.AddNew(params.db_index, key, std::move(tvalue), at_ms); + } catch (bad_alloc& e) { + return OpStatus::OUT_OF_MEMORY; + } EngineShard* shard = db_slice_.shard_owner(); @@ -208,7 +211,6 @@ OpStatus SetCmd::SetExisting(const SetParams& params, PrimeIterator it, ExpireIt // overwrite existing entry. prime_value.SetString(value); - if (value.size() >= kMinTieredLen) { // external storage enabled. EngineShard* shard = db_slice_.shard_owner(); @@ -331,8 +333,8 @@ void StringFamily::GetSet(CmdArgList args, ConnectionContext* cntx) { SetCmd::SetParams sparams{cntx->db_index()}; sparams.prev_val = &prev_val; - ShardId sid = Shard(key, cntx->shard_set->size()); - OpResult result = cntx->shard_set->Await(sid, [&] { + ShardId sid = Shard(key, shard_set->size()); + OpResult result = shard_set->Await(sid, [&] { EngineShard* es = EngineShard::tlocal(); SetCmd cmd(&es->db_slice()); @@ -518,7 +520,7 @@ void StringFamily::MGet(CmdArgList args, ConnectionContext* cntx) { DCHECK_GT(args.size(), 1U); Transaction* transaction = cntx->transaction; - unsigned shard_count = transaction->shard_set()->size(); + unsigned shard_count = shard_set->size(); std::vector mget_resp(shard_count); ConnectionContext* dfly_cntx = static_cast(cntx); diff --git a/src/server/test_utils.cc b/src/server/test_utils.cc index ed8c73a..5c30379 100644 --- a/src/server/test_utils.cc +++ b/src/server/test_utils.cc @@ -63,14 +63,13 @@ void BaseFamilyTest::SetUp() { Service::InitOpts opts; opts.disable_time_update = true; service_->Init(nullptr, nullptr, opts); - ess_ = &service_->shard_set(); expire_now_ = absl::GetCurrentTimeNanos() / 1000000; auto cb = [&](EngineShard* s) { s->db_slice().UpdateExpireBase(expire_now_ - 1000, 0); s->db_slice().UpdateExpireClock(expire_now_); }; - ess_->RunBriefInParallel(cb); + shard_set->RunBriefInParallel(cb); const TestInfo* const test_info = UnitTest::GetInstance()->current_test_info(); LOG(INFO) << "Starting " << test_info->name(); @@ -88,7 +87,7 @@ void BaseFamilyTest::TearDown() { // ts is ms void BaseFamilyTest::UpdateTime(uint64_t ms) { auto cb = [ms](EngineShard* s) { s->db_slice().UpdateExpireClock(ms); }; - ess_->RunBriefInParallel(cb); + shard_set->RunBriefInParallel(cb); } RespExpr BaseFamilyTest::Run(initializer_list list) { @@ -105,7 +104,6 @@ RespExpr BaseFamilyTest::Run(std::string_view id, std::initializer_listArgs(list); auto& context = conn->cmd_cntx; - context.shard_set = ess_; DCHECK(context.transaction == nullptr); @@ -144,7 +142,6 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, string_view key, string_view va TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId()); auto& context = conn->cmd_cntx; - context.shard_set = ess_; DCHECK(context.transaction == nullptr); @@ -166,7 +163,6 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResp TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId()); auto& context = conn->cmd_cntx; - context.shard_set = ess_; service_->DispatchMC(cmd, string_view{}, &context); @@ -193,7 +189,6 @@ auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_listcmd_cntx; - context.shard_set = ess_; service_->DispatchMC(cmd, string_view{}, &context); @@ -249,7 +244,7 @@ RespVec BaseFamilyTest::TestConnWrapper::ParseResponse() { } bool BaseFamilyTest::IsLocked(DbIndex db_index, std::string_view key) const { - ShardId sid = Shard(key, ess_->size()); + ShardId sid = Shard(key, shard_set->size()); KeyLockArgs args; args.db_index = db_index; args.args = ArgSlice{&key, 1}; diff --git a/src/server/test_utils.h b/src/server/test_utils.h index 08dc433..2fb3dba 100644 --- a/src/server/test_utils.h +++ b/src/server/test_utils.h @@ -73,7 +73,6 @@ class BaseFamilyTest : public ::testing::Test { std::unique_ptr pp_; std::unique_ptr service_; - EngineShardSet* ess_ = nullptr; unsigned num_threads_ = 3; absl::flat_hash_map> connections_; diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 7a812c5..facb0f3 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -38,7 +38,7 @@ IntentLock::Mode Transaction::Mode() const { * @param ess * @param cs */ -Transaction::Transaction(const CommandId* cid, EngineShardSet* ess) : cid_(cid), ess_(ess) { +Transaction::Transaction(const CommandId* cid) : cid_(cid) { string_view cmd_name(cid_->name()); if (cmd_name == "EXEC" || cmd_name == "EVAL" || cmd_name == "EVALSHA") { multi_.reset(new Multi); @@ -78,7 +78,7 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) { db_index_ = index; if (IsGlobal()) { - unique_shard_cnt_ = ess_->size(); + unique_shard_cnt_ = shard_set->size(); shard_data_.resize(unique_shard_cnt_); return OpStatus::OK; } @@ -117,7 +117,7 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) { string_view key = args_.front(); unique_shard_cnt_ = 1; - unique_shard_id_ = Shard(key, ess_->size()); + unique_shard_id_ = Shard(key, shard_set->size()); if (needs_reverse_mapping) { reverse_index_.resize(args_.size()); @@ -129,7 +129,7 @@ OpStatus Transaction::InitByArgs(DbIndex index, CmdArgList args) { } // Our shard_data is not sparse, so we must allocate for all threads :( - shard_data_.resize(ess_->size()); + shard_data_.resize(shard_set->size()); CHECK(key_index.step == 1 || key_index.step == 2); DCHECK(key_index.step == 1 || (args.size() % 2) == 1); @@ -444,11 +444,11 @@ void Transaction::ScheduleInternal() { if (span_all) { is_active = [](uint32_t) { return true; }; - num_shards = ess_->size(); + num_shards = shard_set->size(); // Lock shards auto cb = [mode](EngineShard* shard) { shard->shard_lock()->Acquire(mode); }; - ess_->RunBriefInParallel(std::move(cb)); + shard_set->RunBriefInParallel(std::move(cb)); } else { num_shards = unique_shard_cnt_; DCHECK_GT(num_shards, 0u); @@ -470,7 +470,7 @@ void Transaction::ScheduleInternal() { lock_granted_cnt.fetch_add(res.second, memory_order_relaxed); }; - ess_->RunBriefInParallel(std::move(cb), is_active); + shard_set->RunBriefInParallel(std::move(cb), is_active); if (success.load(memory_order_acquire) == num_shards) { // We allow out of order execution only for single hop transactions. @@ -494,7 +494,7 @@ void Transaction::ScheduleInternal() { success.fetch_sub(CancelShardCb(shard), memory_order_relaxed); }; - ess_->RunBriefInParallel(std::move(cancel), is_active); + shard_set->RunBriefInParallel(std::move(cancel), is_active); CHECK_EQ(0u, success.load(memory_order_relaxed)); } @@ -548,7 +548,7 @@ OpStatus Transaction::ScheduleSingleHop(RunnableType cb) { } }; - ess_->Add(unique_shard_id_, std::move(schedule_cb)); // serves as a barrier. + shard_set->Add(unique_shard_id_, std::move(schedule_cb)); // serves as a barrier. } else { // Transaction spans multiple shards or it's global (like flushdb) or multi. if (!multi_) @@ -571,7 +571,7 @@ void Transaction::UnlockMulti() { DCHECK(multi_); using KeyList = vector>; - vector sharded_keys(ess_->size()); + vector sharded_keys(shard_set->size()); // It's LE and not EQ because there may be callbacks in progress that increase use_count_. DCHECK_LE(1u, use_count()); @@ -585,7 +585,7 @@ void Transaction::UnlockMulti() { DCHECK_EQ(prev, 0u); for (ShardId i = 0; i < shard_data_.size(); ++i) { - ess_->Add(i, [&] { UnlockMultiShardCb(sharded_keys, EngineShard::tlocal()); }); + shard_set->Add(i, [&] { UnlockMultiShardCb(sharded_keys, EngineShard::tlocal()); }); } WaitForShardCallbacks(); DCHECK_GE(use_count(), 1u); @@ -683,13 +683,13 @@ void Transaction::ExecuteAsync() { // IsArmedInShard is the protector of non-thread safe data. if (!is_global && unique_shard_cnt_ == 1) { - ess_->Add(unique_shard_id_, std::move(cb)); // serves as a barrier. + shard_set->Add(unique_shard_id_, std::move(cb)); // serves as a barrier. } else { for (ShardId i = 0; i < shard_data_.size(); ++i) { auto& sd = shard_data_[i]; if (!is_global && sd.arg_count == 0) continue; - ess_->Add(i, cb); // serves as a barrier. + shard_set->Add(i, cb); // serves as a barrier. } } } @@ -732,8 +732,8 @@ void Transaction::ExpireBlocking() { auto expire_cb = [this] { ExpireShardCb(EngineShard::tlocal()); }; if (unique_shard_cnt_ == 1) { - DCHECK_LT(unique_shard_id_, ess_->size()); - ess_->Add(unique_shard_id_, move(expire_cb)); + DCHECK_LT(unique_shard_id_, shard_set->size()); + shard_set->Add(unique_shard_id_, move(expire_cb)); } else { for (ShardId i = 0; i < shard_data_.size(); ++i) { auto& sd = shard_data_[i]; @@ -741,7 +741,7 @@ void Transaction::ExpireBlocking() { if (sd.arg_count == 0) continue; - ess_->Add(i, expire_cb); + shard_set->Add(i, expire_cb); } } @@ -963,7 +963,7 @@ bool Transaction::WaitOnWatch(const time_point& tp) { DCHECK_EQ(0, sd.local_mask & ARMED); if (sd.arg_count == 0) continue; - ess_->Add(i, converge_cb); + shard_set->Add(i, converge_cb); } // Wait for all callbacks to conclude. diff --git a/src/server/transaction.h b/src/server/transaction.h index 7a72eb1..8f5b98a 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -59,7 +59,7 @@ class Transaction { EXPIRED_Q = 0x40, // timed-out and should be garbage collected from the blocking queue. }; - Transaction(const CommandId* cid, EngineShardSet* ess); + explicit Transaction(const CommandId* cid); OpStatus InitByArgs(DbIndex index, CmdArgList args); @@ -156,10 +156,6 @@ class Transaction { return coordinator_state_ & COORD_OOO; } - EngineShardSet* shard_set() { - return ess_; - } - // Registers transaction into watched queue and blocks until a) either notification is received. // or b) tp is reached. If tp is time_point::max() then waits indefinitely. // Expects that the transaction had been scheduled before, and uses Execute(.., true) to register. @@ -287,7 +283,7 @@ class Transaction { std::unique_ptr multi_; // Initialized when the transaction is multi/exec. const CommandId* cid_; - EngineShardSet* ess_; + TxId txid_{0}; std::atomic notify_txid_{kuint64max}; diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index 7466809..3025bca 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -1089,7 +1089,7 @@ void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) { return SendAtLeastOneKeyError(cntx); } - vector> maps(cntx->shard_set->size(), OpStatus::SKIPPED); + vector> maps(shard_set->size(), OpStatus::SKIPPED); auto cb = [&](Transaction* t, EngineShard* shard) { maps[shard->shard_id()] = @@ -1386,7 +1386,7 @@ void ZSetFamily::ZUnionStore(CmdArgList args, ConnectionContext* cntx) { return SendAtLeastOneKeyError(cntx); } - vector> maps(cntx->shard_set->size()); + vector> maps(shard_set->size()); auto cb = [&](Transaction* t, EngineShard* shard) { maps[shard->shard_id()] =