Replies: 2 comments
-
Hi - thanks for the question! Unlike CPUs and some GPUs, TPUs have very little support for 64-bit operations, and typically you'll have the best experience on TPUs if you leave X64 mode disabled.
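If you really do need an f64 inverse, one possible workaround (a sketch, not an official recommendation; it assumes your host has a CPU backend available alongside the TPU, and relies on JAX running the computation on the device where the data lives) is to place the array on a CPU device, where 64-bit LU decomposition is implemented:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types globally (off by default in JAX).
jax.config.update("jax_enable_x64", True)

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # float64 once X64 is enabled

# Move the data to a CPU device; the computation follows the data,
# so the f64 LU decomposition runs on the CPU backend, not the TPU.
cpu = jax.devices("cpu")[0]
a_cpu = jax.device_put(a, cpu)
inv_a = jnp.linalg.inv(a_cpu)

print(inv_a.dtype)  # float64
```

The obvious trade-off is that the inverse runs at CPU speed and the result has to be transferred back if the rest of the computation stays on the TPU.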
-
Thanks Jake!
-
Dear List:
This is a basic question about precision; many similar queries already exist in the forum, so apologies if this repeats them.
When I run the following code at the default precision (f32) on a TPU v2, it works fine:
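(Reconstructed from the traceback below — the file is newPrecision.py with jnp.linalg.inv(a) at line 11; the example matrix a is my assumption, since the original snippet wasn't included:)

```python
# newPrecision.py -- minimal reproduction; the matrix `a` is an assumed example
import jax
import jax.numpy as jnp

# jax.config.update("jax_enable_x64", True)  # uncommenting this triggers the error on TPU

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(jnp.linalg.inv(a))
```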
But when the line
jax.config.update("jax_enable_x64", True)
is uncommented, it fails with this error:
Traceback (most recent call last):
File "newPrecision.py", line 11, in
jnp.linalg.inv(a)
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
return self.bind_with_trace(top_trace, args, params)
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 1120, in _pjit_call_impl_python
compiled = _pjit_lower(
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/gpkmohan_mbcet/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: Only F32 and C64 types are implemented in LuDecomposition; got shape f64[2,2]: @ 0x7f66f6de03ff (unknown)
@ 0x7f66f1e89d81 (unknown)
@ 0x7f66f552cea4 (unknown)
@ 0x7f66ef7e8004 (unknown)
@ 0x7f66f6e1da41 (unknown)
@ 0x7f66f6e1be9d (unknown)
@ 0x7f66f6e1b7ae (unknown)
@ 0x7f66f1d5d859 (unknown)
@ 0x7f66f1d6fae0 (unknown)
@ 0x7f66f1d808dd (unknown)
@ 0x7f66f1d7fe2b (unknown)
@ 0x7f66f029cdd0 (unknown)
@ 0x7f66f02995be (unknown)
@ 0x7f66f029857f (unknown)
@ 0x7f66f163c136 (unknown)
@ 0x7f66f0240513 (unknown)
@ 0x7f66f01e3dbd (unknown)
@ 0x7f66f01dad1d (unknown)
@ 0x7f67999ee484 xla::InitializeArgsAndCompile()
@ 0x7f67999eea1a xla::PjRtCApiClient::Compile()
@ 0x7f679c0d9791 xla::ifrt::PjRtLoadedExecutable::Create()
@ 0x7f679c0d005c xla::ifrt::PjRtCompiler::Compile()
@ 0x7f6799990936 xla::PyClient::Compile()
@ 0x7f67996718f7 pybind11::detail::argument_loader<>::call_impl<>()
@ 0x7f67996720b0 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
@ 0x7f6799638cbf pybind11::cpp_function::dispatcher()
@ 0x5e66b9 PyCFunction_Call
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "newPrecision.py", line 11, in
jnp.linalg.inv(a)
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: Only F32 and C64 types are implemented in LuDecomposition; got shape f64[2,2]
The error clearly indicates "...Only F32 and C64 types are implemented in LuDecomposition" (perhaps only on TPUs and not on GPUs/CPUs, I guess), and I believe LU decomposition is what jnp's inv function uses internally.
Is there any solution for computing the inverse of a matrix in JAX with f64 precision (one that doesn't rely on the f32-bound LU decomposition)? Or are there other GitHub libraries that can do this (especially on TPUs) at that precision?
Thanks in advance
Kiran