Skip to content

Commit

Permalink
Merge pull request #1116 from Bears-R-Us/str-compare
Browse files Browse the repository at this point in the history
Use `computeOnSegments` for string comparisons
  • Loading branch information
reuster986 authored Feb 17, 2022
2 parents e43c9c7 + 553b972 commit 929f788
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 36 deletions.
68 changes: 33 additions & 35 deletions src/SegmentedArray.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,11 @@ module SegmentedArray {
moduleName = getModuleName(),
errorClass="ArgumentError");
}
if useHash {
const lh = lss.hash();
const rh = rss.hash();
return if polarity then (lh == rh) else (lh != rh);
}
ref oD = lss.offsets.aD;
// Start by assuming all elements differ, then correct for those that are equal
// This translates to an initial value of false for == and true for !=
Expand Down Expand Up @@ -1091,53 +1096,46 @@ module SegmentedArray {

/* Test an array of strings for equality against a constant string. Return a boolean
vector the same size as the array. */
operator ==(ss:SegString, testStr: string) {
return compare(ss, testStr, true);
operator ==(ss:SegString, testStr: string) throws {
return compare(ss, testStr, SegFunction.StringCompareLiteralEq);
}

/* Test an array of strings for inequality against a constant string. Return a boolean
vector the same size as the array. */
operator !=(ss:SegString, testStr: string) {
return compare(ss, testStr, false);
operator !=(ss:SegString, testStr: string) throws {
return compare(ss, testStr, SegFunction.StringCompareLiteralNeq);
}

/* Per-segment equality test against a literal string.

   values:  the backing array that holds the segment
   rng:     the index range of the segment within values
   testStr: the literal to compare against

   Returns true only when the segment, interpreted as a string,
   equals testStr. */
inline proc stringCompareLiteralEq(values, rng, testStr) {
  // A segment whose size is not numBytes + 1 cannot match, so skip the
  // string materialization entirely.
  // NOTE(review): the "+ 1" presumably accounts for a trailing null byte
  // stored with each segment -- confirm against the segment layout.
  if rng.size != (testStr.numBytes + 1) {
    return false;
  }
  return interpretAsString(values, rng) == testStr;
}

/* Per-segment inequality test against a literal string.

   values:  the backing array that holds the segment
   rng:     the index range of the segment within values
   testStr: the literal to compare against

   Returns true when the segment differs from testStr, either because
   its size does not match or because its contents differ. */
inline proc stringCompareLiteralNeq(values, rng, testStr) {
  // Inequality is the exact negation of the equality helper: a size
  // mismatch yields true, otherwise the contents decide.
  return !stringCompareLiteralEq(values, rng, testStr);
}

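/* Element-wise comparison of an array of strings against a target string.
The comparison mode determines whether the result marks equality (result
is true where elements equal the target) or inequality (result is true
where elements differ from the target). With the polarity-based signature
this is polarity=true / polarity=false; with the SegFunction-based
signature it is SegFunction.StringCompareLiteralEq /
SegFunction.StringCompareLiteralNeq. */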
proc compare(ss:SegString, testStr: string, param polarity: bool) {
ref oD = ss.offsets.aD;
// Initially assume all elements equal the target string, then correct errors
// For ==, this means everything starts true; for !=, everything starts false
var truth: [oD] bool = polarity;
// Early exit for zero-length result
if (ss.size == 0) {
return truth;
}
ref values = ss.values.a;
ref vD = ss.values.aD;
ref offsets = ss.offsets.a;
// Use a whole-array strategy, where the ith byte from every segment is checked simultaneously
// This will do len(testStr) parallel loops, but loops will have low overhead
for (b, i) in zip(testStr.chpl_bytes(), 0..) {
forall (t, o, idx) in zip(truth, offsets, oD) with (var agg = newDstAggregator(bool)) {
if ((o+i > vD.high) || (b != values[o+i])) {
// Strings are not equal, so change the output
// For ==, output is now false; for !=, output is now true
agg.copy(t, !polarity);
}
}
}
// Check the length by checking that the next byte is null
forall (t, o, idx) in zip(truth, offsets, oD) with (var agg = newDstAggregator(bool)) {
if ((o+testStr.size > vD.high) || (0 != values[o+testStr.size])) {
// Strings are not equal, so change the output
// For ==, output is now false; for !=, output is now true
agg.copy(t, !polarity);
}
/* Element-wise comparison of an array of strings against a target string.

   ss:       the segmented string array to test
   testStr:  the literal every element is compared against
   function: compile-time selector of the comparison; the string operators
             pass SegFunction.StringCompareLiteralEq for == and
             SegFunction.StringCompareLiteralNeq for !=

   Returns a boolean result with one entry per string in ss.
   Declared throws because computeOnSegments throws. */
proc compare(ss:SegString, const testStr: string, param function: SegFunction) throws {
  if testStr.numBytes == 0 {
    // Comparing against the empty string is a quick check for zero length.
    // NOTE(review): the "- 1" presumably discounts a trailing null byte
    // counted by getLengths() -- confirm.
    const lengths = ss.getLengths() - 1;
    // The fast path must honor the requested comparison: returning
    // (lengths == 0) unconditionally would make `ss != ""` report true for
    // exactly the empty strings. `function` is a param, so this branch is
    // resolved at compile time.
    if function == SegFunction.StringCompareLiteralNeq {
      return (lengths != 0);
    } else {
      return (lengths == 0);
    }
  }
  return computeOnSegments(ss.offsets.a, ss.values.a, function, bool, testStr);
}

private config const in1dAssocSortThreshold = 10**6;
Expand Down
10 changes: 9 additions & 1 deletion src/SegmentedComputation.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ module SegmentedComputation {
StringToNumericStrict,
StringToNumericIgnore,
StringToNumericReturnValidity,
StringCompareLiteralEq,
StringCompareLiteralNeq,
}

proc computeOnSegments(segments: [?D] int, values: [?vD] ?t, param function: SegFunction, type retType) throws {
proc computeOnSegments(segments: [?D] int, values: [?vD] ?t, param function: SegFunction, type retType, const strArg: string = "") throws {
// type retType = if (function == SegFunction.StringToNumericReturnValidity) then (outType, bool) else outType;
var res: [D] retType;
if (D.size == 0) {
Expand Down Expand Up @@ -82,6 +84,12 @@ module SegmentedComputation {
when SegFunction.StringToNumericReturnValidity {
agg.copy(res[i], stringToNumericReturnValidity(values, start..#len, retType[0]));
}
when SegFunction.StringCompareLiteralEq {
agg.copy(res[i], stringCompareLiteralEq(values, start..#len, strArg));
}
when SegFunction.StringCompareLiteralNeq {
agg.copy(res[i], stringCompareLiteralNeq(values, start..#len, strArg));
}
otherwise {
compilerError("Unrecognized segmented function");
}
Expand Down

0 comments on commit 929f788

Please sign in to comment.