Skip to content

Commit

Permalink
Element accessor: Allow -ve indexes like Python
Browse files Browse the repository at this point in the history
  • Loading branch information
shahramn committed Nov 26, 2023
1 parent 8804b33 commit 20ff3ae
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 18 deletions.
53 changes: 40 additions & 13 deletions src/grib_accessor_class_element.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,24 @@ static void init(grib_accessor* a, const long l, grib_arguments* c)
self->element = grib_arguments_get_long(hand, c, n++);
}

static int check_element_index(const char* func, const char* array_name, long index, size_t size)
{
const grib_context* c = grib_context_get_default();
if (index < 0 || index >= size) {
grib_context_log(c, GRIB_LOG_ERROR, "%s: Invalid element index %ld for array '%s'. Value must be between 0 and %zu",
func, index, array_name, size - 1);
return GRIB_INVALID_ARGUMENT;
}
return GRIB_SUCCESS;
}

static int unpack_long(grib_accessor* a, long* val, size_t* len)
{
grib_accessor_element* self = (grib_accessor_element*)a;
int ret = 0;
size_t size = 0;
long* ar = NULL;
grib_context* c = a->context;
const grib_context* c = a->context;
grib_handle* hand = grib_handle_of_accessor(a);

if (*len < 1) {
Expand All @@ -140,10 +151,12 @@ static int unpack_long(grib_accessor* a, long* val, size_t* len)
if ((ret = grib_get_long_array_internal(hand, self->array, ar, &size)) != GRIB_SUCCESS)
return ret;

if (self->element < 0 || self->element >= size) {
grib_context_log(c, GRIB_LOG_ERROR, "Invalid element %ld for array '%s'. Value must be between 0 and %zu",
self->element, self->array, size - 1);
ret = GRIB_INVALID_ARGUMENT;
// An index of -x means the xth item from the end of the list, so ar[-1] means the last item in ar
if (self->element < 0) {
self->element = size + self->element;
}

if ((ret = check_element_index(__func__, self->array, self->element, size)) != GRIB_SUCCESS) {
goto the_end;
}

Expand All @@ -160,7 +173,7 @@ static int pack_long(grib_accessor* a, const long* val, size_t* len)
int ret = 0;
size_t size = 0;
long* ar = NULL;
grib_context* c = a->context;
const grib_context* c = a->context;
grib_handle* hand = grib_handle_of_accessor(a);

if (*len < 1) {
Expand All @@ -180,11 +193,23 @@ static int pack_long(grib_accessor* a, const long* val, size_t* len)
if ((ret = grib_get_long_array_internal(hand, self->array, ar, &size)) != GRIB_SUCCESS)
return ret;

// An index of -x means the xth item from the end of the list, so ar[-1] means the last item in ar
if (self->element < 0) {
self->element = size + self->element;
}

if ((ret = check_element_index(__func__, self->array, self->element, size)) != GRIB_SUCCESS) {
goto the_end;
}

Assert(self->element >= 0);
Assert(self->element < size);
ar[self->element] = *val;

if ((ret = grib_set_long_array_internal(hand, self->array, ar, size)) != GRIB_SUCCESS)
return ret;
goto the_end;

the_end:
grib_context_free(c, ar);
return ret;
}
Expand All @@ -195,8 +220,8 @@ static int unpack_double(grib_accessor* a, double* val, size_t* len)
int ret = 0;
size_t size = 0;
double* ar = NULL;
grib_context* c = a->context;
grib_handle* hand = grib_handle_of_accessor(a);
const grib_context* c = a->context;
const grib_handle* hand = grib_handle_of_accessor(a);

if (*len < 1) {
ret = GRIB_ARRAY_TOO_SMALL;
Expand All @@ -215,10 +240,12 @@ static int unpack_double(grib_accessor* a, double* val, size_t* len)
if ((ret = grib_get_double_array_internal(hand, self->array, ar, &size)) != GRIB_SUCCESS)
return ret;

if (self->element < 0 || self->element >= size) {
grib_context_log(c, GRIB_LOG_ERROR, "Invalid element %ld for array '%s'. Value must be between 0 and %zu",
self->element, self->array, size - 1);
ret = GRIB_INVALID_ARGUMENT;
// An index of -x means the xth item from the end of the list, so ar[-1] means the last item in ar
if (self->element < 0) {
self->element = size + self->element;
}

if ((ret = check_element_index(__func__, self->array, self->element, size)) != GRIB_SUCCESS) {
goto the_end;
}

Expand Down
2 changes: 1 addition & 1 deletion tests/grib_ecc-1406.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ cat > $tempFilt <<EOF
set iDirectionIncrement = 10000;
set jDirectionIncrement = 10000;
meta lastVal element(values, numberOfValues - 1);
meta lastVal element(values, -1); # Like Python
set lastVal = 42;
write;
Expand Down
21 changes: 17 additions & 4 deletions tests/grib_element.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,28 @@ cat > $tempFilt <<EOF
meta elemA element(pl, Nj - 3);
meta elemB element(pl, Nj - 2);
meta elemC element(pl, Nj - 1);
print "elemA=[elemA], elemB=[elemB], elemC=[elemC]";
meta elemZ element(pl, -1); # another way of getting the last element
print "elemA=[elemA], elemB=[elemB], elemC=[elemC], elemZ=[elemZ]";
EOF
${tools_dir}/grib_filter $tempFilt $input > $tempText
echo "elemA=36, elemB=25, elemC=20" > $tempRef
echo "elemA=36, elemB=25, elemC=20, elemZ=20" > $tempRef
diff $tempRef $tempText


# Invalid element
# Invalid element indexes
cat > $tempFilt <<EOF
meta badElem element(pl, -1);
meta badElem element(pl, -97);
print "[badElem]";
EOF
set +e
${tools_dir}/grib_filter $tempFilt $input > $tempText 2>&1
status=$?
set -e
[ $status -ne 0 ]
grep -q "Invalid element.*Value must be between 0 and 95" $tempText

cat > $tempFilt <<EOF
meta badElem element(pl, 197);
print "[badElem]";
EOF
set +e
Expand All @@ -42,4 +54,5 @@ set -e
grep -q "Invalid element.*Value must be between 0 and 95" $tempText


# Clean up
rm -f $tempRef $tempText $tempFilt

0 comments on commit 20ff3ae

Please sign in to comment.