Skip to content

Commit

Permalink
Add is_close function.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivorforce committed Nov 11, 2024
1 parent deebb07 commit 86427fd
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 1 deletion.
12 changes: 12 additions & 0 deletions doc_classes/nd.xml
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,18 @@
Equivalent to [code]nd.as_array(array, nd.DType.Int64)[/code].
</description>
</method>
<method name="is_close" qualifiers="static">
<return type="NDArray" />
<param index="0" name="a" type="Variant" />
<param index="1" name="b" type="Variant" />
<param index="2" name="rtol" type="float" default="1e-05" />
<param index="3" name="atol" type="float" default="1e-08" />
<param index="4" name="equal_nan" type="bool" default="false" />
<description>
Returns a boolean array where two arrays are element-wise equal within a tolerance.
The tolerance values are positive, typically very small numbers. The relative difference (rtol * abs(b)) and the absolute difference atol are added together to compare against the absolute difference between a and b.
</description>
</method>
<method name="less" qualifiers="static">
<return type="NDArray" />
<param index="0" name="a" type="Variant" />
Expand Down
16 changes: 16 additions & 0 deletions docs/classes/class_nd.rst
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ Methods
+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| :ref:`NDArray<class_NDArray>` | :ref:`int64<class_nd_method_int64>`\ (\ array\: ``Variant``\ ) |static| |
+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| :ref:`NDArray<class_NDArray>` | :ref:`is_close<class_nd_method_is_close>`\ (\ a\: ``Variant``, b\: ``Variant``, rtol\: ``float`` = 1e-05, atol\: ``float`` = 1e-08, equal_nan\: ``bool`` = false\ ) |static| |
+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| :ref:`NDArray<class_NDArray>` | :ref:`less<class_nd_method_less>`\ (\ a\: ``Variant``, b\: ``Variant``\ ) |static| |
+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| :ref:`NDArray<class_NDArray>` | :ref:`less_equal<class_nd_method_less_equal>`\ (\ a\: ``Variant``, b\: ``Variant``\ ) |static| |
Expand Down Expand Up @@ -1323,6 +1325,20 @@ Equivalent to ``nd.as_array(array, nd.DType.Int64)``.

----

.. _class_nd_method_is_close:

.. rst-class:: classref-method

:ref:`NDArray<class_NDArray>` **is_close**\ (\ a\: ``Variant``, b\: ``Variant``, rtol\: ``float`` = 1e-05, atol\: ``float`` = 1e-08, equal_nan\: ``bool`` = false\ ) |static| :ref:`🔗<class_nd_method_is_close>`

Returns a boolean array where two arrays are element-wise equal within a tolerance.

The tolerance values are positive, typically very small numbers. The relative difference (rtol \* abs(b)) and the absolute difference atol are added together to compare against the absolute difference between a and b.

.. rst-class:: classref-item-separator

----

.. _class_nd_method_less:

.. rst-class:: classref-method
Expand Down
2 changes: 1 addition & 1 deletion docs/setup/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Upcoming Changes (main branch)
- Added bitwise functions (``bitwise_and``, ``bitwise_or``, ``bitwise_xor``, ``bitwise_not``, ``bitwise_left_shift``, ``bitwise_right_shift``).
- Added matrix ``diagonal``, ``diag`` and ``trace`` functions.
- Added ``transpose`` and ``flatten`` to ``NDArray`` methods.
- Added ``array_equal`` and ``all_close`` functions.
- Added ``is_close``, ``array_equal`` and ``all_close`` functions.

**Changed**

Expand Down
7 changes: 7 additions & 0 deletions src/nd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ void nd::_bind_methods() {
godot::ClassDB::bind_static_method("nd", D_METHOD("greater_equal", "a", "b"), &nd::greater_equal);
godot::ClassDB::bind_static_method("nd", D_METHOD("less", "a", "b"), &nd::less);
godot::ClassDB::bind_static_method("nd", D_METHOD("less_equal", "a", "b"), &nd::less_equal);
godot::ClassDB::bind_static_method("nd", D_METHOD("is_close", "a", "b", "rtol", "atol", "equal_nan"), &nd::is_close, DEFVAL(1e-05), DEFVAL(1e-08), DEFVAL(false));

godot::ClassDB::bind_static_method("nd", D_METHOD("logical_and", "a", "b"), &nd::logical_and);
godot::ClassDB::bind_static_method("nd", D_METHOD("logical_or", "a", "b"), &nd::logical_or);
Expand Down Expand Up @@ -1091,6 +1092,12 @@ Ref<NDArray> nd::less_equal(const Variant& a, const Variant& b) {
return VARRAY_MAP2(less_equal, a, b);
}

Ref<NDArray> nd::is_close(const Variant& a, const Variant& b, double_t rtol, double_t atol, bool equal_nan) {
return map_variants_as_arrays_with_target([rtol, atol, equal_nan](const va::VArrayTarget target, const std::shared_ptr<va::VArray>& a, const std::shared_ptr<va::VArray>& b) {
va::is_close(va::store::default_allocator, target, a->data, b->data, rtol, atol, equal_nan);
}, a, b);
}

Ref<NDArray> nd::logical_and(const Variant& a, const Variant& b) {
return VARRAY_MAP2(logical_and, a, b);
}
Expand Down
1 change: 1 addition & 0 deletions src/nd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class nd : public Object {
static Ref<NDArray> greater_equal(const Variant& a, const Variant& b);
static Ref<NDArray> less(const Variant& a, const Variant& b);
static Ref<NDArray> less_equal(const Variant& a, const Variant& b);
static Ref<NDArray> is_close(const Variant& a, const Variant& b, double_t rtol, double_t atol, bool equal_nan);

// Logical.
static Ref<NDArray> logical_and(const Variant& a, const Variant& b);
Expand Down
31 changes: 31 additions & 0 deletions src/vatensor/comparison.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,37 @@ void va::less_equal(VStoreAllocator& allocator, VArrayTarget target, const VData
);
}

template <typename A, typename B>
void is_close(VStoreAllocator& allocator, VArrayTarget target, const A& a, const B& b, double rtol, double atol, bool equal_nan) {
va::xoperation_inplace<
Feature::is_close,
promote::common_in_nat_out
>(
[rtol, atol, equal_nan](auto&& a, auto&& b) {
return xt::isclose(std::forward<decltype(a)>(a), std::forward<decltype(b)>(b), rtol, atol, equal_nan);
},
allocator,
target,
a,
b
);
}

void va::is_close(VStoreAllocator& allocator, VArrayTarget target, const VData& a, const VData& b, double rtol, double atol, bool equal_nan) {
#ifndef NUMDOT_DISABLE_SCALAR_OPTIMIZATION
if (va::dimension(a) == 0) {
::is_close(allocator, target, b, va::to_single_value(a), rtol, atol, equal_nan);
return;
}
if (va::dimension(b) == 0) {
::is_close(allocator, target, a, va::to_single_value(b), rtol, atol, equal_nan);
return;
}
#endif

::is_close(allocator, target, a, b, rtol, atol, equal_nan);
}

bool va::array_equal(const VData& a, const VData& b) {
return va::vreduce<
Feature::array_equal,
Expand Down
1 change: 1 addition & 0 deletions src/vatensor/comparison.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace va {
void greater_equal(VStoreAllocator& allocator, VArrayTarget target, const VData& a, const VData& b);
void less(VStoreAllocator& allocator, VArrayTarget target, const VData& a, const VData& b);
void less_equal(VStoreAllocator& allocator, VArrayTarget target, const VData& a, const VData& b);
void is_close(VStoreAllocator& allocator, VArrayTarget target, const VData& a, const VData& b, double rtol = 1e-05, double atol = 1e-08, bool equal_nan = false);

bool array_equal(const VData& a, const VData& b);
bool all_close(const VData& a, const VData& b, double rtol = 1e-05, double atol = 1e-08, bool equal_nan = false);
Expand Down
1 change: 1 addition & 0 deletions src/vatensor/vfeature.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace va {
greater_equal,
less,
less_equal,
is_close,
array_equal,
all_close,

Expand Down

0 comments on commit 86427fd

Please sign in to comment.