diff --git a/src/server/json_family.cc b/src/server/json_family.cc index 36a89ef..451452e 100644 --- a/src/server/json_family.cc +++ b/src/server/json_family.cc @@ -141,28 +141,37 @@ bool JsonErrorHandler(json_errc ec, const ser_context&) { return false; } -OpResult GetJson(const OpArgs& op_args, string_view key) { - OpResult it_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_STRING); - if (!it_res.ok()) - return it_res.status(); - +optional ConstructJsonFromString(string_view val) { error_code ec; json_decoder decoder; - const PrimeValue& pv = it_res.value()->second; - - string val = GetString(op_args.shard, pv); basic_json_parser parser(basic_json_decode_options{}, &JsonErrorHandler); parser.update(val); parser.finish_parse(decoder, ec); if (!decoder.is_valid()) { - return OpStatus::SYNTAX_ERR; + return nullopt; } return decoder.get_result(); } +OpResult GetJson(const OpArgs& op_args, string_view key) { + OpResult it_res = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_STRING); + if (!it_res.ok()) + return it_res.status(); + + const PrimeValue& pv = it_res.value()->second; + + string val = GetString(op_args.shard, pv); + optional j = ConstructJsonFromString(val); + if (!j) { + return OpStatus::SYNTAX_ERR; + } + + return *j; +} + // Returns the index of the next right bracket optional GetNextIndex(string_view str) { size_t current_idx = 0; @@ -692,8 +701,113 @@ OpResult> OpArrTrim(const OpArgs& op_args, string_view key, str return vec; } +// Returns numeric vector that represents the new length of the array at each path. +OpResult> OpArrInsert(const OpArgs& op_args, string_view key, string_view path, + int index, const vector& new_values) { + OpResult result = GetJson(op_args, key); + if (!result) { + return result.status(); + } + + bool out_of_boundaries_encountered = false; + vector vec; + // Insert user-supplied value into the supplied index that should be valid. + // If at least one index isn't valid within an array in the json doc, the operation is discarded. + // Negative indexes start from the end of the array. + auto cb = [&](const string& path, json& val) { + if (out_of_boundaries_encountered) { + return; + } + + if (!val.is_array()) { + vec.emplace_back(nullopt); + return; + } + + size_t removal_index; + if (index < 0) { + if (val.empty()) { + out_of_boundaries_encountered = true; + return; + } + + int temp_index = index + val.size(); + if (temp_index < 0) { + out_of_boundaries_encountered = true; + return; + } + + removal_index = temp_index; + } else { + if ((size_t)index > val.size()) { + out_of_boundaries_encountered = true; + return; + } + + removal_index = index; + } + + auto it = next(val.array_range().begin(), removal_index); + for (auto& new_val : new_values) { + it = val.insert(it, new_val); + it++; + } + + vec.emplace_back(val.size()); + }; + + json j = result.value(); + error_code ec = JsonReplace(j, path, cb); + if (ec) { + VLOG(1) << "Failed to evaluate expression on json with error: " << ec.message(); + return OpStatus::SYNTAX_ERR; + } + + if (out_of_boundaries_encountered) { + return OpStatus::OUT_OF_RANGE; + } + + SetString(op_args, key, j.as_string()); + return vec; +} + } // namespace +void JsonFamily::ArrInsert(CmdArgList args, ConnectionContext* cntx) { + string_view key = ArgS(args, 1); + string_view path = ArgS(args, 2); + int index = -1; + + if (!absl::SimpleAtoi(ArgS(args, 3), &index)) { + VLOG(1) << "Failed to convert the following value to numeric: " << ArgS(args, 3); + (*cntx)->SendError(kInvalidIntErr); + return; + } + + vector new_values; + for (size_t i = 4; i < args.size(); i++) { + optional val = ConstructJsonFromString(ArgS(args, i)); + if (!val) { + (*cntx)->SendError(kSyntaxErr); + return; + } + + new_values.emplace_back(move(*val)); + } + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpArrInsert(t->GetOpArgs(shard), key, path, index, new_values); + }; + + Transaction* trans = cntx->transaction; + OpResult> result = trans->ScheduleSingleHopT(move(cb)); + if (result) { + PrintOptVec(cntx, result); + } else { + (*cntx)->SendError(result.status()); + } +} + void JsonFamily::ArrTrim(CmdArgList args, ConnectionContext* cntx) { string_view key = ArgS(args, 1); string_view path = ArgS(args, 2); @@ -1102,6 +1216,8 @@ void JsonFamily::Register(CommandRegistry* registry) { *registry << CI{"JSON.CLEAR", CO::WRITE | CO::DENYOOM | CO::FAST, 3, 1, 1, 1}.HFUNC(Clear); *registry << CI{"JSON.ARRPOP", CO::WRITE | CO::DENYOOM | CO::FAST, -3, 1, 1, 1}.HFUNC(ArrPop); *registry << CI{"JSON.ARRTRIM", CO::WRITE | CO::DENYOOM | CO::FAST, 5, 1, 1, 1}.HFUNC(ArrTrim); + *registry << CI{"JSON.ARRINSERT", CO::WRITE | CO::DENYOOM | CO::FAST, -4, 1, 1, 1}.HFUNC( + ArrInsert); } } // namespace dfly diff --git a/src/server/json_family.h b/src/server/json_family.h index 62eda3e..69ed932 100644 --- a/src/server/json_family.h +++ b/src/server/json_family.h @@ -34,6 +34,7 @@ class JsonFamily { static void Clear(CmdArgList args, ConnectionContext* cntx); static void ArrPop(CmdArgList args, ConnectionContext* cntx); static void ArrTrim(CmdArgList args, ConnectionContext* cntx); + static void ArrInsert(CmdArgList args, ConnectionContext* cntx); }; } // namespace dfly diff --git a/src/server/json_family_test.cc b/src/server/json_family_test.cc index f25de51..2939193 100644 --- a/src/server/json_family_test.cc +++ b/src/server/json_family_test.cc @@ -746,4 +746,34 @@ TEST_F(JsonFamilyTest, ArrTrim) { EXPECT_EQ(resp, R"({"a":[1,3,2],"nested":{"a":false}})"); } +TEST_F(JsonFamilyTest, ArrInsert) { + string json = R"( + [[], ["a"], ["a", "b"]] + )"; + + auto resp = Run({"set", "json", json}); + ASSERT_THAT(resp, "OK"); + + resp = Run({"JSON.ARRINSERT", "json", "$[*]", "0", R"("a")"}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); + EXPECT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(2), IntArg(3))); + + resp = Run({"GET", "json"}); + EXPECT_EQ(resp, R"([["a"],["a","a"],["a","a","b"]])"); + + resp = Run({"JSON.ARRINSERT", "json", "$[*]", "-1", R"("b")"}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); + EXPECT_THAT(resp.GetVec(), ElementsAre(IntArg(2), IntArg(3), IntArg(4))); + + resp = Run({"GET", "json"}); + EXPECT_EQ(resp, R"([["b","a"],["a","b","a"],["a","a","b","b"]])"); + + resp = Run({"JSON.ARRINSERT", "json", "$[*]", "1", R"("c")"}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); + EXPECT_THAT(resp.GetVec(), ElementsAre(IntArg(3), IntArg(4), IntArg(5))); + + resp = Run({"GET", "json"}); + EXPECT_EQ(resp, R"([["b","c","a"],["a","c","b","a"],["a","c","a","b","b"]])"); +} + } // namespace dfly