diff --git a/autoray/autoray.py b/autoray/autoray.py index 0081306..d3c6a85 100644 --- a/autoray/autoray.py +++ b/autoray/autoray.py @@ -1865,7 +1865,18 @@ def numpy_like(ary, indices_or_sections, axis=0, **kwargs): return numpy_like +def tensorflow_diag(x): + nd = ndim(x) + if nd == 2: + return do("linalg.diag_part", x) + elif nd == 1: + return do("linalg.diag", x) + else: + raise ValueError("Input must be 1- or 2-d.") + + _FUNCS["tensorflow", "to_numpy"] = tensorflow_to_numpy +_FUNCS["tensorflow", "diag"] = tensorflow_diag _SUBMODULE_ALIASES["tensorflow", "log"] = "tensorflow.math" _SUBMODULE_ALIASES["tensorflow", "conj"] = "tensorflow.math" @@ -1873,7 +1884,6 @@ def numpy_like(ary, indices_or_sections, axis=0, **kwargs): _SUBMODULE_ALIASES["tensorflow", "imag"] = "tensorflow.math" _SUBMODULE_ALIASES["tensorflow", "power"] = "tensorflow.math" _SUBMODULE_ALIASES["tensorflow", "count_nonzero"] = "tensorflow.math" -_SUBMODULE_ALIASES["tensorflow", "diag"] = "tensorflow.linalg" _SUBMODULE_ALIASES["tensorflow", "trace"] = "tensorflow.linalg" _SUBMODULE_ALIASES["tensorflow", "tril"] = "tensorflow.linalg" _SUBMODULE_ALIASES["tensorflow", "triu"] = "tensorflow.linalg" @@ -1888,7 +1898,6 @@ def numpy_like(ary, indices_or_sections, axis=0, **kwargs): _FUNC_ALIASES["tensorflow", "arange"] = "range" _FUNC_ALIASES["tensorflow", "tril"] = "band_part" _FUNC_ALIASES["tensorflow", "triu"] = "band_part" -_FUNC_ALIASES["tensorflow", "diag"] = "tensor_diag" _FUNC_ALIASES["tensorflow", "array"] = "convert_to_tensor" _FUNC_ALIASES["tensorflow", "astype"] = "cast" _FUNC_ALIASES["tensorflow", "power"] = "pow"