From eb7120c65f4eb3d9dd104481685b3ea6abc383b1 Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Tue, 7 Nov 2023 01:39:19 +0000 Subject: [PATCH] only pv with package --- MaxText/layers.py | 4 ++-- MaxText/train.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/MaxText/layers.py b/MaxText/layers.py index 18485c8d8..902748285 100644 --- a/MaxText/layers.py +++ b/MaxText/layers.py @@ -257,7 +257,7 @@ def __call__(self, inputs: Array) -> Array: contract_ind = tuple(range(0, len(axis))) - if self.never_quantize or not cfg.int8_training: + if True or self.never_quantize or not cfg.int8_training: return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) else: aqt_key = self.make_rng('aqt') @@ -688,7 +688,7 @@ def attend(self, query: Array) -> Array: in NLP models. """ dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype - if not self.config.int8_training: + if True or not self.config.int8_training: return maxtext_dot(query, jnp.asarray(self.embedding, dtype).T) else: aqt_cfg = get_aqt_cfg(self.config) diff --git a/MaxText/train.py b/MaxText/train.py index bed22e032..fd844cbbe 100755 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -27,7 +27,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" XLA_DUMP_DIR="/tmp/xla_dumps" -os.environ["XLA_FLAGS"]=f"--xla_dump_to={XLA_DUMP_DIR}" +#os.environ["XLA_FLAGS"]=f"--xla_dump_to={XLA_DUMP_DIR}" print(f"Found {jax.device_count()} devices.") from typing import Sequence