Skip to content

Commit

Permalink
refactor spark split function
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoyangxiaozhu committed Jun 19, 2024
1 parent cf6ae3a commit 6fb17bc
Show file tree
Hide file tree
Showing 2 changed files with 438 additions and 150 deletions.
194 changes: 114 additions & 80 deletions velox/functions/sparksql/SplitFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,91 +14,131 @@
* limitations under the License.
*/

#include <iostream>
#include <utility>

#include "velox/expression/DecodedArgs.h"
#include "velox/expression/VectorFunction.h"
#include "velox/expression/VectorWriters.h"
#include "velox/functions/lib/Re2Functions.h"

namespace facebook::velox::functions::sparksql {
namespace {

/// This class only implements the basic split version in which the pattern is a
/// single character
class SplitCharacter final : public exec::VectorFunction {
class Split final : public exec::VectorFunction {
public:
explicit SplitCharacter(const char pattern) : pattern_{pattern} {
static constexpr std::string_view kRegexChars = ".$|()[{^?*+\\";
VELOX_CHECK(
kRegexChars.find(pattern) == std::string::npos,
"This version of split supports single-length non-regex patterns");
}

Split() {}
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& /* outputType */,
exec::EvalCtx& context,
VectorPtr& result) const override {
exec::LocalDecodedVector input(context, *args[0], rows);

// Get the decoded vectors out of arguments.
const bool noLimit = (args.size() == 2);
exec::DecodedArgs decodedArgs(rows, args, context);
DecodedVector* strings = decodedArgs.at(0);
DecodedVector* delims = decodedArgs.at(1);
DecodedVector* limits = noLimit ? nullptr : decodedArgs.at(2);
BaseVector::ensureWritable(rows, ARRAY(VARCHAR()), context.pool(), result);
exec::VectorWriter<Array<Varchar>> resultWriter;
resultWriter.init(*result->as<ArrayVector>());

rows.applyToSelected([&](vector_size_t row) {
resultWriter.setOffset(row);
auto& arrayWriter = resultWriter.current();

const StringView& current = input->valueAt<StringView>(row);
const char* pos = current.begin();
const char* end = pos + current.size();
const char* delim;
do {
delim = std::find(pos, end, pattern_);
arrayWriter.add_item().setNoCopy(StringView(pos, delim - pos));
pos = delim + 1; // Skip past delim.
} while (delim != end);

resultWriter.commit();
});
int64_t limit = std::numeric_limits<int64_t>::max();
if (!noLimit) {
limit = limits->valueAt<int64_t>(0);
if (limit <= 0) {
limit = std::numeric_limits<int64_t>::max();
}
}
// Optimization for the (flat, const, const) case.
if (strings->isIdentityMapping() and delims->isConstantMapping() and
(noLimit or limits->isConstantMapping())) {
const auto* rawStrings = strings->data<StringView>();
const auto delim = delims->valueAt<StringView>(0);
rows.applyToSelected([&](vector_size_t row) {
applyInner(rawStrings[row], delim, limit, row, resultWriter);
});
} else {
// The rest of the cases are handled through this general path and no
// direct access.
rows.applyToSelected([&](vector_size_t row) {
applyInner(
strings->valueAt<StringView>(row),
delims->valueAt<StringView>(row),
limit,
row,
resultWriter);
});
}

resultWriter.finish();

// Reference the input StringBuffers since we did not deep copy above.
// Ensure that our result elements vector uses the same string buffer as
// the input vector of strings.
result->as<ArrayVector>()
->elements()
->as<FlatVector<StringView>>()
->acquireSharedStringBuffers(args[0].get());
->acquireSharedStringBuffers(strings->base());
}

private:
const char pattern_;
};

/// This class will be updated in the future as we support more variants of
/// split
class Split final : public exec::VectorFunction {
public:
Split() {}

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& /* outputType */,
exec::EvalCtx& context,
VectorPtr& result) const override {
auto delimiterVector = args[1]->as<ConstantVector<StringView>>();
VELOX_CHECK(
delimiterVector, "Split function supports only constant delimiter");
auto patternString = args[1]->as<ConstantVector<StringView>>()->valueAt(0);
VELOX_CHECK_EQ(
patternString.size(),
1,
"split only supports only single-character pattern");
char pattern = patternString.data()[0];
SplitCharacter splitCharacter(pattern);
splitCharacter.apply(rows, args, nullptr, context, result);
inline void applyInner(
StringView input,
const StringView delim,
int64_t limit,
vector_size_t row,
exec::VectorWriter<Array<Varchar>>& resultWriter) const {
// Add new array (for the new row) to our array vector.
resultWriter.setOffset(row);
auto& arrayWriter = resultWriter.current();

// Trivial case of converting string to array with 1 element.
if (limit == 1) {
arrayWriter.add_item().setNoCopy(input);
resultWriter.commit();
return;
}

// We walk through our input cutting off the pieces using the delimiter and
// adding them to the elements vector, until we reached the end of the
// string or the limit.
int32_t addedElements{0};
auto* re = cache_.findOrCompile(delim);
const auto re2String = re2::StringPiece(input.data(), input.size());
size_t pos = 0;
const char* start = input.data();
re2::StringPiece subMatches[1];
while (re->Match(
re2String, pos, input.size(), RE2::Anchor::UNANCHORED, subMatches, 1)) {
const auto fullMatch = subMatches[0];
auto offset = fullMatch.data() - start;
const auto size = fullMatch.size();

if (size == 0) {
// delimer is empty string
offset += 1;
}

if (offset > input.size()) {
break;
}

arrayWriter.add_item().setNoCopy(
StringView(input.data() + pos, offset - pos));
pos = offset + size;
++addedElements;
// If the next element should be the last, leave the loop.
if (addedElements + 1 == limit) {
break;
}
}

// Add the rest of the string and we are done.
// Note, that the rest of the string can be empty - we still add it.
arrayWriter.add_item().setNoCopy(
StringView(input.data() + pos, input.size() - pos));
resultWriter.commit();
}

private:
mutable functions::detail::ReCache cache_;
};

/// The function returns specialized version of split based on the constant
Expand All @@ -109,31 +149,25 @@ std::shared_ptr<exec::VectorFunction> createSplit(
const std::string& /*name*/,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
BaseVector* constantPattern = inputArgs[1].constantValue.get();

if (inputArgs.size() > 3 || inputArgs[0].type->isVarchar() ||
inputArgs[1].type->isVarchar() || (constantPattern == nullptr)) {
return std::make_shared<Split>();
}
auto pattern = constantPattern->as<ConstantVector<StringView>>()->valueAt(0);
if (pattern.size() != 1) {
return std::make_shared<Split>();
}
char charPattern = pattern.data()[0];
// TODO: Add support for zero-length pattern, 2-character pattern
// TODO: add support for general regex pattern using R2
return std::make_shared<SplitCharacter>(charPattern);
return std::make_shared<Split>();
}

std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures;
// varchar, varchar -> array(varchar)
return {exec::FunctionSignatureBuilder()
.returnType("array(varchar)")
.argumentType("varchar")
.constantArgumentType("varchar")
.build()};
signatures.emplace_back(exec::FunctionSignatureBuilder()
.returnType("array(varchar)")
.argumentType("varchar")
.constantArgumentType("varchar")
.build());
signatures.emplace_back(exec::FunctionSignatureBuilder()
.returnType("array(varchar)")
.argumentType("varchar")
.constantArgumentType("varchar")
.argumentType("bigint")
.build());
return signatures;
}

} // namespace

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
Expand Down
Loading

0 comments on commit 6fb17bc

Please sign in to comment.