From 3ba6cca99ec6d3a7294dc22baf19529c2eb1a120 Mon Sep 17 00:00:00 2001 From: Jun FENG <99384777+6fj@users.noreply.github.com> Date: Thu, 23 Nov 2023 11:30:13 +0800 Subject: [PATCH] repo-sync-2023-11-23T11:26:31+0800 (#2) --- docker/build.sh | 5 +- psi/psi/utils/BUILD.bazel | 4 +- psi/psi/utils/arrow_csv_batch_provider.cc | 76 ++++++++++--------- psi/psi/utils/arrow_csv_batch_provider.h | 19 ++--- .../utils/arrow_csv_batch_provider_test.cc | 28 ++++--- psi/psi/utils/hash_bucket_cache.cc | 4 +- 6 files changed, 74 insertions(+), 62 deletions(-) diff --git a/docker/build.sh b/docker/build.sh index 96a4ca1..d51cc0c 100644 --- a/docker/build.sh +++ b/docker/build.sh @@ -67,8 +67,9 @@ DOCKER_REG="secretflow" IMAGE_TAG=${DOCKER_REG}/psi-anolis8:${VERSION} LATEST_TAG=${DOCKER_REG}/psi-anolis8:latest -if [[ LATEST -eq 0 ]]; then - echo -e "Build psi binary ${GREEN}PSI ${PSI_VERSION}${NO_COLOR}..." +echo -e "Build psi binary ${GREEN}PSI ${PSI_VERSION}${NO_COLOR}..." + +if [[ SKIP -eq 0 ]]; then docker run -it --rm --mount type=bind,source="$(pwd)/../../psi",target=/home/admin/dev/src -w /home/admin/dev --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow/release-ci:1.2 /home/admin/dev/src/docker/entry.sh echo -e "Finish building psi binary ${GREEN}${IMAGE_LITE_TAG}${NO_COLOR}" fi diff --git a/psi/psi/utils/BUILD.bazel b/psi/psi/utils/BUILD.bazel index c153795..ed62bec 100644 --- a/psi/psi/utils/BUILD.bazel +++ b/psi/psi/utils/BUILD.bazel @@ -42,8 +42,8 @@ psi_cc_library( srcs = ["hash_bucket_cache.cc"], hdrs = ["hash_bucket_cache.h"], deps = [ + ":arrow_csv_batch_provider", ":multiplex_disk_cache", - "//psi/psi/utils:batch_provider", "@com_google_absl//absl/strings", ], ) @@ -53,8 +53,8 @@ psi_cc_library( srcs = ["csv_checker.cc"], hdrs = ["csv_checker.h"], deps = [ + ":utils", "//psi/psi/io", - "//psi/psi/utils", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@org_apache_arrow//:arrow", diff --git a/psi/psi/utils/arrow_csv_batch_provider.cc b/psi/psi/utils/arrow_csv_batch_provider.cc index 87d791e..cab178e 100644 --- a/psi/psi/utils/arrow_csv_batch_provider.cc +++ b/psi/psi/utils/arrow_csv_batch_provider.cc @@ -21,55 +21,59 @@ #include "arrow/datum.h" #include "spdlog/spdlog.h" +#include "psi/psi/utils/utils.h" + namespace psi::psi { ArrowCsvBatchProvider::ArrowCsvBatchProvider( const std::string& file_path, const std::vector& keys, - const std::string& separator, size_t block_size) - : block_size_(block_size), - file_path_(file_path), - keys_(keys), - separator_(separator) { + size_t batch_size) + : batch_size_(batch_size), file_path_(file_path), keys_(keys) { Init(); } std::vector ArrowCsvBatchProvider::ReadNextBatch() { - std::shared_ptr batch; - arrow::Status status = reader_->ReadNext(&batch); - if (!status.ok()) { - YACL_THROW("Read csv error."); - } + std::vector res; - if (!batch) { - SPDLOG_INFO("Reach the end of csv file {}.", file_path_); - return {}; - } + while (res.size() < batch_size_) { + bool new_batch = false; - std::vector join_cols; + if (!batch_ || idx_in_batch_ >= batch_->num_rows()) { + arrow::Status status = reader_->ReadNext(&batch_); + if (!status.ok()) { + YACL_THROW("Read csv error."); + } - arrow::compute::CastOptions cast_options; + new_batch = true; + } - for (const auto& col : batch->columns()) { - join_cols.emplace_back( - arrow::compute::Cast(arrow::Datum(*col), arrow::utf8(), cast_options) - .ValueOrDie()); - } + if (!batch_) { + SPDLOG_INFO("Reach the end of csv file {}.", file_path_); + return res; + } - join_cols.emplace_back(arrow::MakeScalar(separator_)); + if (new_batch) { + idx_in_batch_ = 0; - arrow::Datum join_datum = - arrow::compute::CallFunction("binary_join_element_wise", join_cols) - .ValueOrDie(); + arrays_.clear(); - std::shared_ptr join_array = std::move(join_datum).make_array(); - auto str_array = std::dynamic_pointer_cast(join_array); + for (const auto& col : batch_->columns()) { + arrays_.emplace_back( + std::dynamic_pointer_cast(col)); + } + } - std::vector res; - for (int i = 0; i < batch->num_rows(); ++i) { - res.emplace_back(str_array->GetString(i)); - } + for (; idx_in_batch_ < batch_->num_rows() && res.size() < batch_size_; + idx_in_batch_++) { + std::vector values; + for (const auto& array : arrays_) { + values.emplace_back(array->Value(idx_in_batch_)); + } - row_cnt_ += batch->num_rows(); + res.emplace_back(KeysJoin(values)); + row_cnt_++; + } + } return res; } @@ -78,19 +82,21 @@ void ArrowCsvBatchProvider::Init() { YACL_ENFORCE(std::filesystem::exists(file_path_), "Input file {} doesn't exist.", file_path_); + YACL_ENFORCE(!keys_.empty(), "You must provide keys."); + arrow::io::IOContext io_context = arrow::io::default_io_context(); infile_ = arrow::io::ReadableFile::Open(file_path_, arrow::default_memory_pool()) .ValueOrDie(); auto read_options = arrow::csv::ReadOptions::Defaults(); - read_options.block_size = block_size_; auto parse_options = arrow::csv::ParseOptions::Defaults(); auto convert_options = arrow::csv::ConvertOptions::Defaults(); - if (!keys_.empty()) { - convert_options.include_columns = keys_; + for (const auto& key : keys_) { + convert_options.column_types[key] = arrow::utf8(); } + convert_options.include_columns = keys_; reader_ = arrow::csv::StreamingReader::Make(io_context, infile_, read_options, parse_options, convert_options) diff --git a/psi/psi/utils/arrow_csv_batch_provider.h b/psi/psi/utils/arrow_csv_batch_provider.h index 24d7c68..e9c9cfb 100644 --- a/psi/psi/utils/arrow_csv_batch_provider.h +++ b/psi/psi/utils/arrow_csv_batch_provider.h @@ -26,35 +26,36 @@ namespace psi::psi { class ArrowCsvBatchProvider : public IBasicBatchProvider { public: - // NOTE(junfeng): block_size is not col num of each batch, which by default is - // 1 << 20 (1 Mb). explicit ArrowCsvBatchProvider(const std::string& file_path, - const std::vector& keys = {}, - const std::string& separator = ",", - size_t block_size = 1 << 20); + const std::vector& keys, + size_t batch_size = 1 << 20); std::vector ReadNextBatch() override; [[nodiscard]] size_t row_cnt() const { return row_cnt_; } - [[nodiscard]] size_t batch_size() const { return block_size_; } + [[nodiscard]] size_t batch_size() const { return batch_size_; } private: void Init(); - const size_t block_size_; + const size_t batch_size_; const std::string file_path_; const std::vector keys_; - const std::string separator_; - size_t row_cnt_ = 0; std::shared_ptr infile_; std::shared_ptr reader_; + + std::shared_ptr batch_; + + int64_t idx_in_batch_ = 0; + + std::vector> arrays_; }; } // namespace psi::psi diff --git a/psi/psi/utils/arrow_csv_batch_provider_test.cc b/psi/psi/utils/arrow_csv_batch_provider_test.cc index 17a4814..24bdd5e 100644 --- a/psi/psi/utils/arrow_csv_batch_provider_test.cc +++ b/psi/psi/utils/arrow_csv_batch_provider_test.cc @@ -38,10 +38,13 @@ TEST(ArrowCsvBatchProvider, works) { file.close(); { - ArrowCsvBatchProvider provider(file_path); + ArrowCsvBatchProvider provider(file_path, {"id1", "id2", "id3"}, 1); EXPECT_EQ(provider.ReadNextBatch(), - std::vector( - {"1,one,first", "2,two,second", "3,three,third"})); + std::vector({"1,one,first"})); + EXPECT_EQ(provider.ReadNextBatch(), + std::vector({"2,two,second"})); + EXPECT_EQ(provider.ReadNextBatch(), + std::vector({"3,three,third"})); EXPECT_EQ(provider.row_cnt(), 3); EXPECT_TRUE(provider.ReadNextBatch().empty()); EXPECT_TRUE(provider.ReadNextBatch().empty()); @@ -49,22 +52,23 @@ TEST(ArrowCsvBatchProvider, works) { } { - ArrowCsvBatchProvider provider(file_path, {}, "#"); - EXPECT_EQ(provider.ReadNextBatch(), - std::vector( - {"1#one#first", "2#two#second", "3#three#third"})); - } - - { - ArrowCsvBatchProvider provider(file_path, {"id2", "id1"}); + ArrowCsvBatchProvider provider(file_path, {"id2", "id1"}, 3); EXPECT_EQ(provider.ReadNextBatch(), std::vector({"one,1", "two,2", "three,3"})); + EXPECT_EQ(provider.row_cnt(), 3); + EXPECT_TRUE(provider.ReadNextBatch().empty()); + EXPECT_TRUE(provider.ReadNextBatch().empty()); + EXPECT_EQ(provider.row_cnt(), 3); } { - ArrowCsvBatchProvider provider(file_path, {"id3"}); + ArrowCsvBatchProvider provider(file_path, {"id3"}, 5); EXPECT_EQ(provider.ReadNextBatch(), std::vector({"first", "second", "third"})); + EXPECT_EQ(provider.row_cnt(), 3); + EXPECT_TRUE(provider.ReadNextBatch().empty()); + EXPECT_TRUE(provider.ReadNextBatch().empty()); + EXPECT_EQ(provider.row_cnt(), 3); } std::error_code ec; diff --git a/psi/psi/utils/hash_bucket_cache.cc b/psi/psi/utils/hash_bucket_cache.cc index dd2deb1..e61f3ca 100644 --- a/psi/psi/utils/hash_bucket_cache.cc +++ b/psi/psi/utils/hash_bucket_cache.cc @@ -22,7 +22,7 @@ #include "absl/strings/str_split.h" #include "spdlog/spdlog.h" -#include "psi/psi/utils/batch_provider.h" +#include "psi/psi/utils/arrow_csv_batch_provider.h" namespace psi::psi { @@ -80,7 +80,7 @@ std::unique_ptr CreateCacheFromCsv( auto bucket_cache = std::make_unique(cache_dir, bucket_num, use_scoped_tmp_dir); - auto batch_provider = std::make_unique( + auto batch_provider = std::make_unique( csv_path, schema_names, read_batch_size); while (true) { auto items = batch_provider->ReadNextBatch();