Skip to content

Commit

Permalink
Pre-commit config
Browse files Browse the repository at this point in the history
  • Loading branch information
khatwanimohit committed Apr 23, 2024
1 parent 718d9e7 commit ad072dd
Show file tree
Hide file tree
Showing 24 changed files with 42 additions and 31 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ on:
- cron: '0 */2 * * *'

jobs:
# IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO 'gpu' job
# IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO 'gpu' job
tpu:
strategy:
fail-fast: false
Expand Down Expand Up @@ -99,7 +99,7 @@ jobs:
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \
'python3 pedagogical_examples/shmap_collective_matmul.py'
# IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO 'tpu' job
# IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO 'tpu' job
gpu:
strategy:
fail-fast: false
Expand Down
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
repos:
- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
hooks:
- id: codespell
name: Running codespell for typos
entry: codespell -w --skip="*.txt,pylintrc,.*" .
2 changes: 1 addition & 1 deletion MaxText/accelerator_to_spec_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
""" Static map of TPU names such as v4-8 to properties such as chip layout."""

""" !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
IF YOU MODIFY THIS FILE YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO
IF YOU MODIFY THIS FILE YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO
UserFacingNameToSystemCharacteristics in xpk/xpk.py !!!!! """

from dataclasses import dataclass
Expand Down
4 changes: 2 additions & 2 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1
dcn_sequence_parallelism: 1 # never recommended
dcn_tensor_parallelism: 1 # never recommeneded
dcn_tensor_parallelism: 1 # never recommended
dcn_autoregressive_parallelism: 1 # never recommended
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
Expand Down Expand Up @@ -197,7 +197,7 @@ prefill_cache_dir: "" # If set and load_from_prefill_dir, decode.py reads from d
autoregressive_decode_assert: ""

enable_profiler: False
# If set to true, upload all profiler xplane results from all hosts. Otherwise, only upload the xplane reuslt from the first host.
# If set to true, upload all profiler xplane results from all hosts. Otherwise, only upload the xplane result from the first host.
upload_all_profiler_results: False
# Skip first n steps for profiling, to omit things like compilation and to give
# the iteration time a chance to stabilize.
Expand Down
2 changes: 1 addition & 1 deletion MaxText/generate_param_only_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""

# pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports
"""Trasforms a "full state" including optimzer state to a bfloat16 "parameter state" without optimizer state.
"""Transforms a "full state" including optimizer state to a bfloat16 "parameter state" without optimizer state.
This typically used for turning a state output by training.py into a state than can be consumed by decode.py.
The input "fullstate" is passed in via:
Expand Down
6 changes: 3 additions & 3 deletions MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def reduce_concat_tokens(
):
"""Token-preprocessor to concatenate multiple unrelated documents.
If we want to generate examples of exactly the right length,
(to avoid wasting space on padding), then we use this function, folowed by
(to avoid wasting space on padding), then we use this function, followed by
split_tokens.
Args:
dataset: a tf.data.Dataset with dictionaries containing the key feature_key.
Expand Down Expand Up @@ -219,7 +219,7 @@ def get_datasets(
train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})

eval_ds = eval_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
# note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and splitted to target_length
# note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and split to target_length
# mainly to avoid eval sequences change depending on the number of hosts
eval_ds = rekey(eval_ds, {"inputs": None, "targets": "ids"})

Expand All @@ -243,7 +243,7 @@ def preprocess_dataset(
train_ds = split_tokens_to_targets_length(train_ds, config.max_target_length)
train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)

# note eval_ds is pre tokenized, reduce_concated and splitted to target_length
# note eval_ds is pre tokenized, reduce_concated and split to target_length
# mainly to avoid eval sequences change depending on the number of hosts
train_ds = sequence_packing.pack_dataset(train_ds, config.max_target_length)
eval_ds = sequence_packing.pack_dataset(eval_ds, config.max_target_length)
Expand Down
2 changes: 1 addition & 1 deletion MaxText/input_pipeline/input_pipeline_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __next__(self):

@staticmethod
def raw_generate_synthetic_data(config):
"""Generates a single batch of syntehtic data"""
"""Generates a single batch of synthetic data"""
output = {}
output["inputs"] = jax.numpy.zeros((config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32)
output["inputs_position"] = jax.numpy.zeros(
Expand Down
4 changes: 2 additions & 2 deletions MaxText/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
raise ValueError("num_kv_heads is not defined.")

if self.num_query_heads % self.num_kv_heads != 0:
raise ValueError("Invaid num_kv_heads for GQA.")
raise ValueError("Invalid num_kv_heads for GQA.")

kv_proj = DenseGeneral(
features=(self.num_kv_heads, self.head_dim),
Expand Down Expand Up @@ -918,7 +918,7 @@ def __call__(
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and project the results to an output vector.
There are three modes: training, prefill and autoregression. During training, the KV cahce
There are three modes: training, prefill and autoregression. During training, the KV cache
is ignored. During prefill, the cache is filled. During autoregression the cache is used.
In the cache initialization call, `inputs_q` has a shape [batch, length,
Expand Down
2 changes: 1 addition & 1 deletion MaxText/layers/linears.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def compute_dot_general(inputs, kernel, axis, contract_ind):
kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
if quantizations.in_serve_mode(self.quant):
# During aqt convert state we delete kernel weight from params to save memory.
# Instead they are retreived from the tensors stored in the 'aqt' collection.
# Instead they are retrieved from the tensors stored in the 'aqt' collection.
kernel = jnp.zeros(kernel_shape)
else:
kernel = self.param(
Expand Down
6 changes: 3 additions & 3 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,15 +526,15 @@ def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, z_loss:
logits: [batch, length, num_classes] float array.
targets: categorical one-hot targets [batch, length, num_classes] float
array.
z_loss: coefficient for auxilliary z-loss loss term.
z_loss: coefficient for auxiliary z-loss loss term.
Returns:
tuple with the total loss and the z_loss, both
float arrays with shape [batch, length].
"""
logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
log_softmax = logits - logits_sum
loss = -jnp.sum(targets * log_softmax, axis=-1)
# Add auxilliary z-loss term.
# Add auxiliary z-loss term.
log_z = jnp.squeeze(logits_sum, axis=-1)
total_z_loss = z_loss * jax.lax.square(log_z)
loss += total_z_loss
Expand All @@ -554,7 +554,7 @@ def _cross_entropy_with_logits_fwd(
sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True)
log_softmax = shifted - jnp.log(sum_exp)
loss = -jnp.sum(targets * log_softmax, axis=-1)
# Add auxilliary z-loss term.
# Add auxiliary z-loss term.
log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1)
total_z_loss = z_loss * jax.lax.square(log_z)
loss += total_z_loss
Expand Down
2 changes: 1 addition & 1 deletion MaxText/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01):
perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding
assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, (
"Number of parameters per chip must not be less than in the ideal sharded "
"scenario accross `fsdp`, `fsdp_transpose`,`sequence`, `tensor` axes."
"scenario across `fsdp`, `fsdp_transpose`,`sequence`, `tensor` axes."
)
assert total_num_params_per_chip / perfectly_sharded_params_per_chip - 1 < tolerance, (
f"Number of unsharded parameters exceeds tolerance {tolerance * 100}% " "of total parameters."
Expand Down
2 changes: 1 addition & 1 deletion MaxText/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def adam_pax(
) -> optax.GradientTransformation:
"""Standard Adam optimizer that supports weight decay.
Follows the implemenation in pax/praxis sharded_adam
Follows the implementation in pax/praxis sharded_adam
https://github.com/google/praxis/blob/545e00ab126b823265d70c715950d39333484f38/praxis/optimizers.py#L621
Args:
Expand Down
4 changes: 2 additions & 2 deletions MaxText/tests/gpt3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


def init_random_model_vars(model, rng, example_batch):
"""initialze random model vars."""
"""initialize random model vars."""
model_vars = model.init(
{"params": rng, "aqt": rng},
example_batch["inputs"],
Expand Down Expand Up @@ -90,7 +90,7 @@ def test_logits_numerically(self):
# ground truth values are calculated from paxml after loading above model_vars
# note we expect all xents are the same except the padding one since:
# paxml applies padding in mlp layer
# while maxtext implementaiton applies padding in attention mask instead
# while maxtext implementation applies padding in attention mask instead
# the two implementation are equivalent in valid non-padding tokens
per_example_xent_truth = jnp.array([[31.976467, 25.806253, 17.311134, 45.362663, 0.0]], dtype=jnp.float32)
logits, _ = self.model.apply(
Expand Down
2 changes: 1 addition & 1 deletion MaxText/tests/llama_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def apply_rotary_emb(
freqs_cis: jnp.ndarray,
dtype: jnp.dtype = jnp.bfloat16,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Apply the computed Rotary Postional Embedding"""
"""Apply the computed Rotary Positional Embedding"""
reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)

Expand Down
2 changes: 1 addition & 1 deletion MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def validate_train_config(config):
max_logging.log("WARNING: 'dataset_path' might be pointing your local file system")
if not config.base_output_directory.startswith("gs://"):
max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system")
assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive interger."
assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer."


def get_first_step(state):
Expand Down
2 changes: 1 addition & 1 deletion MaxText/train_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_topology_mesh(config):

def get_shaped_inputs(topology_mesh, config):
"""Get shaped abstractions of inputs to train_step: state, batch and rng"""
# Construct the model and optimizier to get shaped versions of the state
# Construct the model and optimizer to get shaped versions of the state
quant = quantizations.configure_quantization(config)
model = Transformer(config, topology_mesh, quant=quant)
# The learning_rate_schedule is baked into the compiled object.
Expand Down
2 changes: 1 addition & 1 deletion MaxText/vertex_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def setup(self):
return tensorboard_url

def upload_data(self, tensorboard_dir):
"""Starts an uploader to continously monitor and upload data to Vertex Tensorboard.
"""Starts an uploader to continuously monitor and upload data to Vertex Tensorboard.
Args:
tensorboard_dir: directory that contains Tensorboard data.
Expand Down
2 changes: 2 additions & 0 deletions constraints_gpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ pandas==2.2.1
platformdirs==4.2.0
pluggy==1.4.0
portpicker==1.6.0
pre-commit==3.7.0
promise==2.3
proto-plus==1.23.0
protobuf==3.20.3
Expand All @@ -107,6 +108,7 @@ pydantic==1.10.14
pydot==2.0.0
pyglove==0.4.4
Pygments==2.17.2
pyink==24.3.0
pylint==3.1.0
pyparsing==3.1.2
pytest==8.1.1
Expand Down
2 changes: 1 addition & 1 deletion end_to_end/tpu/gemma/2b/test_gemma.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ idx=$(date +%Y-%m-%d-%H-%M)
export MODEL_VARIATION='2b'

# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \
# Non-Googlers please remember to use seperate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
# Non-Googlers please remember to use separate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing.
export CHKPT_BUCKET=gs://maxtext-gemma/flax
export MODEL_BUCKET=gs://maxtext-gemma
Expand Down
2 changes: 1 addition & 1 deletion end_to_end/tpu/gemma/7b/1_test_gemma.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ MODEL_VARIATION='7b'


# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \
# Please use seperate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($BASE_OUTPUT_PATH).
# Please use separate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($BASE_OUTPUT_PATH).
# Non-Googlers please remember to point CHKPT_BUCKET to GCS buckets that you own
export CHKPT_BUCKET=gs://maxtext-gemma/flax

Expand Down
2 changes: 1 addition & 1 deletion getting_started/Run_MaxText_via_multihost_runner.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ either be a TPUVM or not, but it cannot be one of the workers. If your runner ma
Choose names for your TPUs and QR:
```
TPU_PREFIX=$YOUR_TPU_NAME # Use new names when you create new TPUs
QR_ID=$TPU_PREFIX # Convenient to re-use the node names, but can be different
QR_ID=$TPU_PREFIX # Convenient to reuse the node names, but can be different
```
Choose the number of nodes (we use 2 below, but you may customize this and other feature of your TPU(s))
```
Expand Down
5 changes: 2 additions & 3 deletions maxtext_gpu_dependencies.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,12 @@ RUN mkdir -p /deps
# Set the working directory in the container
WORKDIR /deps

# Copy necessary build files to docker container
COPY setup.sh requirements.txt constraints_gpu.txt /deps/
# Copy all files from local workspace into docker container
COPY . .
RUN ls .

RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}

COPY . .

WORKDIR /deps
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ optax
protobuf==3.20.3
pylint
pytest
pyink
pre-commit
pytype
sentencepiece==0.1.97
tensorflow-text>=2.13.0
Expand Down
3 changes: 2 additions & 1 deletion setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ fi
# Install dependencies from requirements.txt
cd $run_name_folder_path && pip install --upgrade pip
if [[ "$MODE" == "pinned" ]]; then
pip3 install -r requirements.txt -c constraints_gpu.txt
pip3 install -U -r requirements.txt -c constraints_gpu.txt
else
pip3 install -U -r requirements.txt
fi
[ -d ".git" ] && pre-commit install

0 comments on commit ad072dd

Please sign in to comment.