Skip to content

Commit

Permalink
repo-sync-2023-11-23T11:26:31+0800 (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
6fj committed Nov 23, 2023
1 parent f6c796d commit 3ba6cca
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 62 deletions.
5 changes: 3 additions & 2 deletions docker/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ DOCKER_REG="secretflow"
IMAGE_TAG=${DOCKER_REG}/psi-anolis8:${VERSION}
LATEST_TAG=${DOCKER_REG}/psi-anolis8:latest

if [[ LATEST -eq 0 ]]; then
echo -e "Build psi binary ${GREEN}PSI ${PSI_VERSION}${NO_COLOR}..."
echo -e "Build psi binary ${GREEN}PSI ${PSI_VERSION}${NO_COLOR}..."

if [[ SKIP -eq 0 ]]; then
docker run -it --rm --mount type=bind,source="$(pwd)/../../psi",target=/home/admin/dev/src -w /home/admin/dev --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow/release-ci:1.2 /home/admin/dev/src/docker/entry.sh
echo -e "Finish building psi binary ${GREEN}${IMAGE_LITE_TAG}${NO_COLOR}"
fi
Expand Down
4 changes: 2 additions & 2 deletions psi/psi/utils/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ psi_cc_library(
srcs = ["hash_bucket_cache.cc"],
hdrs = ["hash_bucket_cache.h"],
deps = [
":arrow_csv_batch_provider",
":multiplex_disk_cache",
"//psi/psi/utils:batch_provider",
"@com_google_absl//absl/strings",
],
)
Expand All @@ -53,8 +53,8 @@ psi_cc_library(
srcs = ["csv_checker.cc"],
hdrs = ["csv_checker.h"],
deps = [
":utils",
"//psi/psi/io",
"//psi/psi/utils",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"@org_apache_arrow//:arrow",
Expand Down
76 changes: 41 additions & 35 deletions psi/psi/utils/arrow_csv_batch_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,55 +21,59 @@
#include "arrow/datum.h"
#include "spdlog/spdlog.h"

#include "psi/psi/utils/utils.h"

namespace psi::psi {

ArrowCsvBatchProvider::ArrowCsvBatchProvider(
const std::string& file_path, const std::vector<std::string>& keys,
const std::string& separator, size_t block_size)
: block_size_(block_size),
file_path_(file_path),
keys_(keys),
separator_(separator) {
size_t batch_size)
: batch_size_(batch_size), file_path_(file_path), keys_(keys) {
Init();
}

std::vector<std::string> ArrowCsvBatchProvider::ReadNextBatch() {
std::shared_ptr<arrow::RecordBatch> batch;
arrow::Status status = reader_->ReadNext(&batch);
if (!status.ok()) {
YACL_THROW("Read csv error.");
}
std::vector<std::string> res;

if (!batch) {
SPDLOG_INFO("Reach the end of csv file {}.", file_path_);
return {};
}
while (res.size() < batch_size_) {
bool new_batch = false;

std::vector<arrow::Datum> join_cols;
if (!batch_ || idx_in_batch_ >= batch_->num_rows()) {
arrow::Status status = reader_->ReadNext(&batch_);
if (!status.ok()) {
YACL_THROW("Read csv error.");
}

arrow::compute::CastOptions cast_options;
new_batch = true;
}

for (const auto& col : batch->columns()) {
join_cols.emplace_back(
arrow::compute::Cast(arrow::Datum(*col), arrow::utf8(), cast_options)
.ValueOrDie());
}
if (!batch_) {
SPDLOG_INFO("Reach the end of csv file {}.", file_path_);
return res;
}

join_cols.emplace_back(arrow::MakeScalar(separator_));
if (new_batch) {
idx_in_batch_ = 0;

arrow::Datum join_datum =
arrow::compute::CallFunction("binary_join_element_wise", join_cols)
.ValueOrDie();
arrays_.clear();

std::shared_ptr<arrow::Array> join_array = std::move(join_datum).make_array();
auto str_array = std::dynamic_pointer_cast<arrow::StringArray>(join_array);
for (const auto& col : batch_->columns()) {
arrays_.emplace_back(
std::dynamic_pointer_cast<arrow::StringArray>(col));
}
}

std::vector<std::string> res;
for (int i = 0; i < batch->num_rows(); ++i) {
res.emplace_back(str_array->GetString(i));
}
for (; idx_in_batch_ < batch_->num_rows() && res.size() < batch_size_;
idx_in_batch_++) {
std::vector<absl::string_view> values;
for (const auto& array : arrays_) {
values.emplace_back(array->Value(idx_in_batch_));
}

row_cnt_ += batch->num_rows();
res.emplace_back(KeysJoin(values));
row_cnt_++;
}
}

return res;
}
Expand All @@ -78,19 +82,21 @@ void ArrowCsvBatchProvider::Init() {
YACL_ENFORCE(std::filesystem::exists(file_path_),
"Input file {} doesn't exist.", file_path_);

YACL_ENFORCE(!keys_.empty(), "You must provide keys.");

arrow::io::IOContext io_context = arrow::io::default_io_context();
infile_ =
arrow::io::ReadableFile::Open(file_path_, arrow::default_memory_pool())
.ValueOrDie();

auto read_options = arrow::csv::ReadOptions::Defaults();
read_options.block_size = block_size_;
auto parse_options = arrow::csv::ParseOptions::Defaults();
auto convert_options = arrow::csv::ConvertOptions::Defaults();

if (!keys_.empty()) {
convert_options.include_columns = keys_;
for (const auto& key : keys_) {
convert_options.column_types[key] = arrow::utf8();
}
convert_options.include_columns = keys_;

reader_ = arrow::csv::StreamingReader::Make(io_context, infile_, read_options,
parse_options, convert_options)
Expand Down
19 changes: 10 additions & 9 deletions psi/psi/utils/arrow_csv_batch_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,35 +26,36 @@ namespace psi::psi {

class ArrowCsvBatchProvider : public IBasicBatchProvider {
public:
// NOTE(junfeng): block_size is the number of bytes read per block, not the
// number of rows in each batch; the default is 1 << 20 (1 MiB).
explicit ArrowCsvBatchProvider(const std::string& file_path,
const std::vector<std::string>& keys = {},
const std::string& separator = ",",
size_t block_size = 1 << 20);
const std::vector<std::string>& keys,
size_t batch_size = 1 << 20);

std::vector<std::string> ReadNextBatch() override;

[[nodiscard]] size_t row_cnt() const { return row_cnt_; }

[[nodiscard]] size_t batch_size() const { return block_size_; }
[[nodiscard]] size_t batch_size() const { return batch_size_; }

private:
void Init();

const size_t block_size_;
const size_t batch_size_;

const std::string file_path_;

const std::vector<std::string> keys_;

const std::string separator_;

size_t row_cnt_ = 0;

std::shared_ptr<arrow::io::ReadableFile> infile_;

std::shared_ptr<arrow::csv::StreamingReader> reader_;

std::shared_ptr<arrow::RecordBatch> batch_;

int64_t idx_in_batch_ = 0;

std::vector<std::shared_ptr<arrow::StringArray>> arrays_;
};

} // namespace psi::psi
28 changes: 16 additions & 12 deletions psi/psi/utils/arrow_csv_batch_provider_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,33 +38,37 @@ TEST(ArrowCsvBatchProvider, works) {
file.close();

{
ArrowCsvBatchProvider provider(file_path);
ArrowCsvBatchProvider provider(file_path, {"id1", "id2", "id3"}, 1);
EXPECT_EQ(provider.ReadNextBatch(),
std::vector<std::string>(
{"1,one,first", "2,two,second", "3,three,third"}));
std::vector<std::string>({"1,one,first"}));
EXPECT_EQ(provider.ReadNextBatch(),
std::vector<std::string>({"2,two,second"}));
EXPECT_EQ(provider.ReadNextBatch(),
std::vector<std::string>({"3,three,third"}));
EXPECT_EQ(provider.row_cnt(), 3);
EXPECT_TRUE(provider.ReadNextBatch().empty());
EXPECT_TRUE(provider.ReadNextBatch().empty());
EXPECT_EQ(provider.row_cnt(), 3);
}

{
ArrowCsvBatchProvider provider(file_path, {}, "#");
EXPECT_EQ(provider.ReadNextBatch(),
std::vector<std::string>(
{"1#one#first", "2#two#second", "3#three#third"}));
}

{
ArrowCsvBatchProvider provider(file_path, {"id2", "id1"});
ArrowCsvBatchProvider provider(file_path, {"id2", "id1"}, 3);
EXPECT_EQ(provider.ReadNextBatch(),
std::vector<std::string>({"one,1", "two,2", "three,3"}));
EXPECT_EQ(provider.row_cnt(), 3);
EXPECT_TRUE(provider.ReadNextBatch().empty());
EXPECT_TRUE(provider.ReadNextBatch().empty());
EXPECT_EQ(provider.row_cnt(), 3);
}

{
ArrowCsvBatchProvider provider(file_path, {"id3"});
ArrowCsvBatchProvider provider(file_path, {"id3"}, 5);
EXPECT_EQ(provider.ReadNextBatch(),
std::vector<std::string>({"first", "second", "third"}));
EXPECT_EQ(provider.row_cnt(), 3);
EXPECT_TRUE(provider.ReadNextBatch().empty());
EXPECT_TRUE(provider.ReadNextBatch().empty());
EXPECT_EQ(provider.row_cnt(), 3);
}

std::error_code ec;
Expand Down
4 changes: 2 additions & 2 deletions psi/psi/utils/hash_bucket_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "absl/strings/str_split.h"
#include "spdlog/spdlog.h"

#include "psi/psi/utils/batch_provider.h"
#include "psi/psi/utils/arrow_csv_batch_provider.h"

namespace psi::psi {

Expand Down Expand Up @@ -80,7 +80,7 @@ std::unique_ptr<HashBucketCache> CreateCacheFromCsv(
auto bucket_cache = std::make_unique<HashBucketCache>(cache_dir, bucket_num,
use_scoped_tmp_dir);

auto batch_provider = std::make_unique<CsvBatchProvider>(
auto batch_provider = std::make_unique<ArrowCsvBatchProvider>(
csv_path, schema_names, read_batch_size);
while (true) {
auto items = batch_provider->ReadNextBatch();
Expand Down

0 comments on commit 3ba6cca

Please sign in to comment.