From 2c7704922fc4ef8d10ec96da98b6172d9fcd54d5 Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Thu, 9 May 2024 18:05:57 -0700 Subject: [PATCH] fix tensorflow diag --- autoray/autoray.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/autoray/autoray.py b/autoray/autoray.py index 0081306..d3c6a85 100644 --- a/autoray/autoray.py +++ b/autoray/autoray.py @@ -1865,7 +1865,18 @@ def numpy_like(ary, indices_or_sections, axis=0, **kwargs): return numpy_like +def tensorflow_diag(x): + nd = ndim(x) + if nd == 2: + return do("linalg.diag_part", x) + elif nd == 1: + return do("linalg.diag", x) + else: + raise ValueError("Input must be 1- or 2-d.") + + _FUNCS["tensorflow", "to_numpy"] = tensorflow_to_numpy +_FUNCS["tensorflow", "diag"] = tensorflow_diag _SUBMODULE_ALIASES["tensorflow", "log"] = "tensorflow.math" _SUBMODULE_ALIASES["tensorflow", "conj"] = "tensorflow.math" @@ -1873,7 +1884,6 @@ def numpy_like(ary, indices_or_sections, axis=0, **kwargs): _SUBMODULE_ALIASES["tensorflow", "imag"] = "tensorflow.math" _SUBMODULE_ALIASES["tensorflow", "power"] = "tensorflow.math" _SUBMODULE_ALIASES["tensorflow", "count_nonzero"] = "tensorflow.math" -_SUBMODULE_ALIASES["tensorflow", "diag"] = "tensorflow.linalg" _SUBMODULE_ALIASES["tensorflow", "trace"] = "tensorflow.linalg" _SUBMODULE_ALIASES["tensorflow", "tril"] = "tensorflow.linalg" _SUBMODULE_ALIASES["tensorflow", "triu"] = "tensorflow.linalg" @@ -1888,7 +1898,6 @@ def numpy_like(ary, indices_or_sections, axis=0, **kwargs): _FUNC_ALIASES["tensorflow", "arange"] = "range" _FUNC_ALIASES["tensorflow", "tril"] = "band_part" _FUNC_ALIASES["tensorflow", "triu"] = "band_part" -_FUNC_ALIASES["tensorflow", "diag"] = "tensor_diag" _FUNC_ALIASES["tensorflow", "array"] = "convert_to_tensor" _FUNC_ALIASES["tensorflow", "astype"] = "cast" _FUNC_ALIASES["tensorflow", "power"] = "pow"