Skip to content

Commit

Permalink
only pv with package
Browse files Browse the repository at this point in the history
  • Loading branch information
gobbleturk committed Nov 7, 2023
1 parent ea4cf93 commit eb7120c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions MaxText/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def __call__(self, inputs: Array) -> Array:

contract_ind = tuple(range(0, len(axis)))

if self.never_quantize or not cfg.int8_training:
if True or self.never_quantize or not cfg.int8_training:
return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
else:
aqt_key = self.make_rng('aqt')
Expand Down Expand Up @@ -688,7 +688,7 @@ def attend(self, query: Array) -> Array:
in NLP models.
"""
dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype
if not self.config.int8_training:
if True or not self.config.int8_training:
return maxtext_dot(query, jnp.asarray(self.embedding, dtype).T)
else:
aqt_cfg = get_aqt_cfg(self.config)
Expand Down
2 changes: 1 addition & 1 deletion MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
XLA_DUMP_DIR="/tmp/xla_dumps"
os.environ["XLA_FLAGS"]=f"--xla_dump_to={XLA_DUMP_DIR}"
#os.environ["XLA_FLAGS"]=f"--xla_dump_to={XLA_DUMP_DIR}"
print(f"Found {jax.device_count()} devices.")

from typing import Sequence
Expand Down

0 comments on commit eb7120c

Please sign in to comment.