diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 7b558794d5a2..ae9af35be653 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +import re from jax._src.lib import xla_client as xc @@ -33,8 +34,7 @@ class SpecifiedLayout(XLACompatibleLayout): def __init__(self, layout: xc.Layout): self._layout = layout - self._layout_str = self._layout.to_string() - self._minor_to_major = self._layout.minor_to_major() + self._layout_str = str(self._layout) def __repr__(self): return f'SpecifiedLayout({self._layout_str})' @@ -50,6 +50,15 @@ def __eq__(self, other): def _to_xla_layout(self) -> str: return self._layout_str + @property + def _minor_to_major(self): + m = re.search("{([0-9,]*):", str(self)) + assert m is not None + m2m_str = m.group(1) + if m2m_str == '': + return () + return tuple(int(x) for x in m2m_str.split(",")) + class LayoutRequest: