From 25becd2d4322366e93e6b9b7871eea5ffaa7567c Mon Sep 17 00:00:00 2001 From: Zacharya Date: Mon, 4 Jul 2022 15:57:00 +0300 Subject: [PATCH] feat(pubsub): implement pubsub command close #90 (#175) * feat(pubsub): implement pubsub command * fix(pubsub): code review * fix(pubsub): code review * fix(pubsub): code review --- .gitignore | 1 + src/server/channel_slice.cc | 19 ++++++++++- src/server/channel_slice.h | 3 ++ src/server/command_registry.h | 3 ++ src/server/main_service.cc | 59 ++++++++++++++++++++++++++++++++++- src/server/main_service.h | 4 +++ 6 files changed, 87 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 5856f2d..2761022 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ genfiles/* *.pyc /CMakeLists.txt.user _deps +releases diff --git a/src/server/channel_slice.cc b/src/server/channel_slice.cc index cd5d1cd..650190d 100644 --- a/src/server/channel_slice.cc +++ b/src/server/channel_slice.cc @@ -72,7 +72,7 @@ auto ChannelSlice::FetchSubscribers(string_view channel) -> vector { } void ChannelSlice::CopySubscribers(const SubscribeMap& src, const std::string& pattern, - vector* dest) { + vector* dest) { for (const auto& sub : src) { ConnectionContext* cntx = sub.first; CHECK(cntx->conn_state.subscribe_info); @@ -85,4 +85,21 @@ void ChannelSlice::CopySubscribers(const SubscribeMap& src, const std::string& p } } +vector ChannelSlice::ListChannels(const string_view pattern) const { + vector res; + for (const auto& k_v : channels_) { + const string& channel = k_v.first; + + if (pattern.empty() || stringmatchlen(pattern.data(), pattern.size(), channel.data(), channel.size(), 0) == 1) { + res.push_back(channel); + } + } + + return res; +} + +size_t ChannelSlice::PatternCount() const { + return patterns_.size(); +} + } // namespace dfly diff --git a/src/server/channel_slice.h b/src/server/channel_slice.h index f2f3d5e..9d8d861 100644 --- a/src/server/channel_slice.h +++ b/src/server/channel_slice.h @@ -40,6 +40,9 @@ class ChannelSlice { void AddGlobPattern(std::string_view pattern, ConnectionContext* me, uint32_t thread_id); void RemoveGlobPattern(std::string_view pattern, ConnectionContext* me); + std::vector ListChannels(const std::string_view pattern) const; + size_t PatternCount() const; + private: struct SubscriberInternal { uint32_t thread_id; // proactor thread id. diff --git a/src/server/command_registry.h b/src/server/command_registry.h index 1c082f3..1441ebb 100644 --- a/src/server/command_registry.h +++ b/src/server/command_registry.h @@ -52,6 +52,9 @@ class CommandId { /** * @brief Construct a new Command Id object * + * When creating a new command use the https://github.com/redis/redis/tree/unstable/src/commands + * files to find the right arguments. + * * @param name * @param mask * @param arity - positive if command has fixed number of required arguments diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 2202612..6eb779c 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1038,6 +1038,62 @@ void Service::Function(CmdArgList args, ConnectionContext* cntx) { return (*cntx)->SendError(err, kSyntaxErrType); } +void Service::PubsubChannels(string_view pattern, ConnectionContext* cntx) { + vector> result_set(shard_set->size()); + + shard_set->RunBriefInParallel([&](EngineShard* shard) { + result_set[shard->shard_id()] = shard->channel_slice().ListChannels(pattern); + }); + + vector union_set; + for (auto&& v : result_set) { + union_set.insert(union_set.end(), v.begin(), v.end()); + } + + (*cntx)->SendStringArr(union_set); +} + +void Service::PubsubPatterns(ConnectionContext* cntx) { + size_t pattern_count = shard_set->Await(0, [&] { return EngineShard::tlocal()->channel_slice().PatternCount(); }); + (*cntx)->SendLong(pattern_count); +} + +void Service::Pubsub(CmdArgList args, ConnectionContext* cntx) { + if (args.size() < 2) { + (*cntx)->SendError(WrongNumArgsError(ArgS(args, 0))); + return; + } + + string_view subcmd = ArgS(args, 1); + + if (subcmd == "HELP") { + string_view help_arr[] = { + "PUBSUB [ [value] [opt] ...]. Subcommands are:", + "CHANNELS []", + "\tReturn the currently active channels matching a (default: '*').", + "NUMPAT", + "\tReturn number of subscriptions to patterns.", + "HELP", + "\tPrints this help."}; + + (*cntx)->SendSimpleStrArr(help_arr, ABSL_ARRAYSIZE(help_arr)); + return; + } + + if (subcmd == "CHANNELS") { + string pattern; + if (args.size() > 2) { + pattern = ArgS(args, 2); + } + + PubsubChannels(pattern, cntx); + } else if (subcmd == "NUMPAT") { + PubsubPatterns(cntx); + } else { + (*cntx)->SendError(UnknownSubCmd(subcmd, "PUBSUB")); + } +} + VarzValue::Map Service::GetVarzStats() { VarzValue::Map res; @@ -1095,7 +1151,8 @@ void Service::RegisterCommands() { << CI{"UNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -1, 0, 0, 0}.MFUNC(Unsubscribe) << CI{"PSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0}.MFUNC(PSubscribe) << CI{"PUNSUBSCRIBE", CO::NOSCRIPT | CO::LOADING, -1, 0, 0, 0}.MFUNC(PUnsubscribe) - << CI{"FUNCTION", CO::NOSCRIPT, 2, 0, 0, 0}.MFUNC(Function); + << CI{"FUNCTION", CO::NOSCRIPT, 2, 0, 0, 0}.MFUNC(Function) + << CI{"PUBSUB", CO::LOADING | CO::FAST, -1, 0, 0, 0}.MFUNC(Pubsub); StreamFamily::Register(®istry_); StringFamily::Register(®istry_); diff --git a/src/server/main_service.h b/src/server/main_service.h index cb40ea3..cff25dd 100644 --- a/src/server/main_service.h +++ b/src/server/main_service.h @@ -98,6 +98,10 @@ class Service : public facade::ServiceInterface { void PUnsubscribe(CmdArgList args, ConnectionContext* cntx); void Function(CmdArgList args, ConnectionContext* cntx); + void Pubsub(CmdArgList args, ConnectionContext* cntx); + void PubsubChannels(std::string_view pattern, ConnectionContext* cntx); + void PubsubPatterns(ConnectionContext* cntx); + struct EvalArgs { std::string_view sha; // only one of them is defined. CmdArgList keys, args;