diff --git a/doc_classes/nd.xml b/doc_classes/nd.xml index 74c6550..cc4c131 100644 --- a/doc_classes/nd.xml +++ b/doc_classes/nd.xml @@ -547,6 +547,18 @@ Equivalent to [code]nd.as_array(array, nd.DType.Int64)[/code]. + + + + + + + + + Returns a boolean array where two arrays are element-wise equal within a tolerance. + The tolerance values are positive, typically very small numbers. The relative difference (rtol * abs(b)) and the absolute difference atol are added together to compare against the absolute difference between a and b. + + diff --git a/docs/classes/class_nd.rst b/docs/classes/class_nd.rst index 99e871f..09f039a 100644 --- a/docs/classes/class_nd.rst +++ b/docs/classes/class_nd.rst @@ -158,6 +158,8 @@ Methods +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | :ref:`NDArray` | :ref:`int64`\ (\ array\: ``Variant``\ ) |static| | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | :ref:`NDArray` | :ref:`is_close`\ (\ a\: ``Variant``, b\: ``Variant``, rtol\: ``float`` = 1e-05, atol\: ``float`` = 1e-08, equal_nan\: ``bool`` = false\ ) |static| | + +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | :ref:`NDArray` | :ref:`less`\ (\ a\: ``Variant``, b\: ``Variant``\ ) |static| | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | :ref:`NDArray` | :ref:`less_equal`\ (\ a\: ``Variant``, b\: ``Variant``\ ) |static| | @@ -1323,6 +1325,20 @@ Equivalent to ``nd.as_array(array, nd.DType.Int64)``. ---- +.. _class_nd_method_is_close: + +.. rst-class:: classref-method + +:ref:`NDArray` **is_close**\ (\ a\: ``Variant``, b\: ``Variant``, rtol\: ``float`` = 1e-05, atol\: ``float`` = 1e-08, equal_nan\: ``bool`` = false\ ) |static| :ref:`🔗` + +Returns a boolean array where two arrays are element-wise equal within a tolerance. + +The tolerance values are positive, typically very small numbers. The relative difference (rtol \* abs(b)) and the absolute difference atol are added together to compare against the absolute difference between a and b. + +.. rst-class:: classref-item-separator + +---- + .. _class_nd_method_less: .. rst-class:: classref-method diff --git a/docs/setup/changelog.rst b/docs/setup/changelog.rst index 7dff96b..c9f7d37 100644 --- a/docs/setup/changelog.rst +++ b/docs/setup/changelog.rst @@ -25,7 +25,7 @@ Upcoming Changes (main branch) - Added bitwise functions (``bitwise_and``, ``bitwise_or``, ``bitwise_xor``, ``bitwise_not``, ``bitwise_left_shift``, ``bitwise_right_shift``). - Added matrix ``diagonal``, ``diag`` and ``trace`` functions. - Added ``transpose`` and ``flatten`` to ``NDArray`` methods. -- Added ``array_equal`` and ``all_close`` functions. +- Added ``is_close``, ``array_equal`` and ``all_close`` functions. **Changed** diff --git a/src/nd.cpp b/src/nd.cpp index bf90e52..09b7454 100644 --- a/src/nd.cpp +++ b/src/nd.cpp @@ -188,6 +188,7 @@ void nd::_bind_methods() { godot::ClassDB::bind_static_method("nd", D_METHOD("greater_equal", "a", "b"), &nd::greater_equal); godot::ClassDB::bind_static_method("nd", D_METHOD("less", "a", "b"), &nd::less); godot::ClassDB::bind_static_method("nd", D_METHOD("less_equal", "a", "b"), &nd::less_equal); + godot::ClassDB::bind_static_method("nd", D_METHOD("is_close", "a", "b", "rtol", "atol", "equal_nan"), &nd::is_close, DEFVAL(1e-05), DEFVAL(1e-08), DEFVAL(false)); godot::ClassDB::bind_static_method("nd", D_METHOD("logical_and", "a", "b"), &nd::logical_and); godot::ClassDB::bind_static_method("nd", D_METHOD("logical_or", "a", "b"), &nd::logical_or); @@ -1091,6 +1092,12 @@ Ref nd::less_equal(const Variant& a, const Variant& b) { return VARRAY_MAP2(less_equal, a, b); } +Ref nd::is_close(const Variant& a, const Variant& b, double_t rtol, double_t atol, bool equal_nan) { + return map_variants_as_arrays_with_target([rtol, atol, equal_nan](const va::VArrayTarget target, const std::shared_ptr& a, const std::shared_ptr& b) { + va::is_close(va::store::default_allocator, target, a->data, b->data, rtol, atol, equal_nan); + }, a, b); +} + Ref nd::logical_and(const Variant& a, const Variant& b) { return VARRAY_MAP2(logical_and, a, b); } diff --git a/src/nd.hpp b/src/nd.hpp index 4543b76..fb45fea 100644 --- a/src/nd.hpp +++ b/src/nd.hpp @@ -177,6 +177,7 @@ class nd : public Object { static Ref greater_equal(const Variant& a, const Variant& b); static Ref less(const Variant& a, const Variant& b); static Ref less_equal(const Variant& a, const Variant& b); + static Ref is_close(const Variant& a, const Variant& b, double_t rtol, double_t atol, bool equal_nan); // Logical. static Ref logical_and(const Variant& a, const Variant& b); diff --git a/src/vatensor/comparison.cpp b/src/vatensor/comparison.cpp index 81e9858..9257a8e 100644 --- a/src/vatensor/comparison.cpp +++ b/src/vatensor/comparison.cpp @@ -256,6 +256,37 @@ void va::less_equal(VStoreAllocator& allocator, VArrayTarget target, const VData ); } +template +void is_close(VStoreAllocator& allocator, VArrayTarget target, const A& a, const B& b, double rtol, double atol, bool equal_nan) { + va::xoperation_inplace< + Feature::is_close, + promote::common_in_nat_out + >( + [rtol, atol, equal_nan](auto&& a, auto&& b) { + return xt::isclose(std::forward(a), std::forward(b), rtol, atol, equal_nan); + }, + allocator, + target, + a, + b + ); +} + +void va::is_close(VStoreAllocator& allocator, VArrayTarget target, const VData& a, const VData& b, double rtol, double atol, bool equal_nan) { +#ifndef NUMDOT_DISABLE_SCALAR_OPTIMIZATION + if (va::dimension(a) == 0) { + ::is_close(allocator, target, b, va::to_single_value(a), rtol, atol, equal_nan); + return; + } + if (va::dimension(b) == 0) { + ::is_close(allocator, target, a, va::to_single_value(b), rtol, atol, equal_nan); + return; + } +#endif + + ::is_close(allocator, target, a, b, rtol, atol, equal_nan); +} + bool va::array_equal(const VData& a, const VData& b) { return va::vreduce< Feature::array_equal, diff --git a/src/vatensor/comparison.hpp b/src/vatensor/comparison.hpp index 9aee218..3d55287 100644 --- a/src/vatensor/comparison.hpp +++ b/src/vatensor/comparison.hpp @@ -10,6 +10,7 @@ namespace va { void greater_equal(VStoreAllocator& allocator, VArrayTarget target, const VData& a, const VData& b); void less(VStoreAllocator& allocator, VArrayTarget target, const VData& a, const VData& b); void less_equal(VStoreAllocator& allocator, VArrayTarget target, const VData& a, const VData& b); + void is_close(VStoreAllocator& allocator, VArrayTarget target, const VData& a, const VData& b, double rtol = 1e-05, double atol = 1e-08, bool equal_nan = false); bool array_equal(const VData& a, const VData& b); bool all_close(const VData& a, const VData& b, double rtol = 1e-05, double atol = 1e-08, bool equal_nan = false); diff --git a/src/vatensor/vfeature.hpp b/src/vatensor/vfeature.hpp index 3337bac..11628d4 100644 --- a/src/vatensor/vfeature.hpp +++ b/src/vatensor/vfeature.hpp @@ -27,6 +27,7 @@ namespace va { greater_equal, less, less_equal, + is_close, array_equal, all_close,