Skip to content

Commit

Permalink
fix tensorflow diag
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed May 10, 2024
1 parent 3e04dbb commit 2c77049
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions autoray/autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1865,15 +1865,25 @@ def numpy_like(ary, indices_or_sections, axis=0, **kwargs):
return numpy_like


def tensorflow_diag(x):
nd = ndim(x)
if nd == 2:
return do("linalg.diag_part", x)
elif nd == 1:
return do("linalg.diag", x)
else:
raise ValueError("Input must be 1- or 2-d.")


_FUNCS["tensorflow", "to_numpy"] = tensorflow_to_numpy
_FUNCS["tensorflow", "diag"] = tensorflow_diag

_SUBMODULE_ALIASES["tensorflow", "log"] = "tensorflow.math"
_SUBMODULE_ALIASES["tensorflow", "conj"] = "tensorflow.math"
_SUBMODULE_ALIASES["tensorflow", "real"] = "tensorflow.math"
_SUBMODULE_ALIASES["tensorflow", "imag"] = "tensorflow.math"
_SUBMODULE_ALIASES["tensorflow", "power"] = "tensorflow.math"
_SUBMODULE_ALIASES["tensorflow", "count_nonzero"] = "tensorflow.math"
_SUBMODULE_ALIASES["tensorflow", "diag"] = "tensorflow.linalg"
_SUBMODULE_ALIASES["tensorflow", "trace"] = "tensorflow.linalg"
_SUBMODULE_ALIASES["tensorflow", "tril"] = "tensorflow.linalg"
_SUBMODULE_ALIASES["tensorflow", "triu"] = "tensorflow.linalg"
Expand All @@ -1888,7 +1898,6 @@ def numpy_like(ary, indices_or_sections, axis=0, **kwargs):
_FUNC_ALIASES["tensorflow", "arange"] = "range"
_FUNC_ALIASES["tensorflow", "tril"] = "band_part"
_FUNC_ALIASES["tensorflow", "triu"] = "band_part"
_FUNC_ALIASES["tensorflow", "diag"] = "tensor_diag"
_FUNC_ALIASES["tensorflow", "array"] = "convert_to_tensor"
_FUNC_ALIASES["tensorflow", "astype"] = "cast"
_FUNC_ALIASES["tensorflow", "power"] = "pow"
Expand Down

0 comments on commit 2c77049

Please sign in to comment.