From ad072dd633c1736a345cce6fd6b5aacb931989d3 Mon Sep 17 00:00:00 2001
From: khatwanimohit
Date: Thu, 18 Apr 2024 03:49:27 +0000
Subject: [PATCH] Pre-commit config

---
 .github/workflows/UnitTests.yml                           | 4 ++--
 .pre-commit-config.yaml                                   | 7 +++++++
 MaxText/accelerator_to_spec_map.py                        | 2 +-
 MaxText/configs/base.yml                                  | 4 ++--
 MaxText/generate_param_only_checkpoint.py                 | 2 +-
 MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py | 6 +++---
 MaxText/input_pipeline/input_pipeline_interface.py        | 2 +-
 MaxText/layers/attentions.py                              | 4 ++--
 MaxText/layers/linears.py                                 | 2 +-
 MaxText/max_utils.py                                      | 6 +++---
 MaxText/maxtext_utils.py                                  | 2 +-
 MaxText/optimizers.py                                     | 2 +-
 MaxText/tests/gpt3_test.py                                | 4 ++--
 MaxText/tests/llama_test.py                               | 2 +-
 MaxText/train.py                                          | 2 +-
 MaxText/train_compile.py                                  | 2 +-
 MaxText/vertex_tensorboard.py                             | 2 +-
 constraints_gpu.txt                                       | 2 ++
 end_to_end/tpu/gemma/2b/test_gemma.sh                     | 2 +-
 end_to_end/tpu/gemma/7b/1_test_gemma.sh                   | 2 +-
 getting_started/Run_MaxText_via_multihost_runner.md       | 2 +-
 maxtext_gpu_dependencies.Dockerfile                       | 5 ++---
 requirements.txt                                          | 2 ++
 setup.sh                                                  | 3 ++-
 24 files changed, 42 insertions(+), 31 deletions(-)
 create mode 100644 .pre-commit-config.yaml

diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 08ff79a8c..e7a05f3a6 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -27,7 +27,7 @@ on:
     - cron: '0 */2 * * *'
 
 jobs:
-  # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO 'gpu' job
+  # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO 'gpu' job
   tpu:
     strategy:
       fail-fast: false
@@ -99,7 +99,7 @@ jobs:
           docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \
           'python3 pedagogical_examples/shmap_collective_matmul.py'
 
-  # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO 'tpu' job
+  # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO 'tpu' job
   gpu:
     strategy:
       fail-fast: false
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 000000000..87e563baf
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,7 @@
+repos:
+  - repo: https://github.com/codespell-project/codespell
+    rev: v2.2.4
+    hooks:
+      - id: codespell
+        name: Running codespell for typos
+        entry: codespell -w --skip="*.txt,pylintrc,.*" .
diff --git a/MaxText/accelerator_to_spec_map.py b/MaxText/accelerator_to_spec_map.py
index 255aef965..27267048c 100644
--- a/MaxText/accelerator_to_spec_map.py
+++ b/MaxText/accelerator_to_spec_map.py
@@ -17,7 +17,7 @@
 """ Static map of TPU names such as v4-8 to properties such as chip layout."""
 
 """ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-IF YOU MODIFY THIS FILE YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO
+IF YOU MODIFY THIS FILE YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO
 UserFacingNameToSystemCharacteristics in xpk/xpk.py !!!!!
""" from dataclasses import dataclass diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index c5a8f9098..1aa573eae 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -142,7 +142,7 @@ dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: 1 dcn_fsdp_transpose_parallelism: 1 dcn_sequence_parallelism: 1 # never recommended -dcn_tensor_parallelism: 1 # never recommeneded +dcn_tensor_parallelism: 1 # never recommended dcn_autoregressive_parallelism: 1 # never recommended ici_data_parallelism: 1 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded @@ -197,7 +197,7 @@ prefill_cache_dir: "" # If set and load_from_prefill_dir, decode.py reads from d autoregressive_decode_assert: "" enable_profiler: False -# If set to true, upload all profiler xplane results from all hosts. Otherwise, only upload the xplane reuslt from the first host. +# If set to true, upload all profiler xplane results from all hosts. Otherwise, only upload the xplane result from the first host. upload_all_profiler_results: False # Skip first n steps for profiling, to omit things like compilation and to give # the iteration time a chance to stabilize. diff --git a/MaxText/generate_param_only_checkpoint.py b/MaxText/generate_param_only_checkpoint.py index 09ea7412b..673b3ce51 100644 --- a/MaxText/generate_param_only_checkpoint.py +++ b/MaxText/generate_param_only_checkpoint.py @@ -15,7 +15,7 @@ """ # pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports -"""Trasforms a "full state" including optimzer state to a bfloat16 "parameter state" without optimizer state. +"""Transforms a "full state" including optimizer state to a bfloat16 "parameter state" without optimizer state. This typically used for turning a state output by training.py into a state than can be consumed by decode.py. The input "fullstate" is passed in via: diff --git a/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py b/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py index 0ca44bdb1..caae39b77 100644 --- a/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py +++ b/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py @@ -91,7 +91,7 @@ def reduce_concat_tokens( ): """Token-preprocessor to concatenate multiple unrelated documents. If we want to generate examples of exactly the right length, - (to avoid wasting space on padding), then we use this function, folowed by + (to avoid wasting space on padding), then we use this function, followed by split_tokens. Args: dataset: a tf.data.Dataset with dictionaries containing the key feature_key. 
@@ -219,7 +219,7 @@ def get_datasets(
   train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})
 
   eval_ds = eval_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-  # note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and splitted to target_length
+  # note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and split to target_length
   # mainly to avoid eval sequences change depending on the number of hosts
   eval_ds = rekey(eval_ds, {"inputs": None, "targets": "ids"})
 
@@ -243,7 +243,7 @@ def preprocess_dataset(
     train_ds = split_tokens_to_targets_length(train_ds, config.max_target_length)
 
   train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)
-  # note eval_ds is pre tokenized, reduce_concated and splitted to target_length
+  # note eval_ds is pre tokenized, reduce_concated and split to target_length
   # mainly to avoid eval sequences change depending on the number of hosts
   train_ds = sequence_packing.pack_dataset(train_ds, config.max_target_length)
   eval_ds = sequence_packing.pack_dataset(eval_ds, config.max_target_length)
diff --git a/MaxText/input_pipeline/input_pipeline_interface.py b/MaxText/input_pipeline/input_pipeline_interface.py
index 47a784829..81d3eaa75 100644
--- a/MaxText/input_pipeline/input_pipeline_interface.py
+++ b/MaxText/input_pipeline/input_pipeline_interface.py
@@ -112,7 +112,7 @@ def __next__(self):
 
   @staticmethod
   def raw_generate_synthetic_data(config):
-    """Generates a single batch of syntehtic data"""
+    """Generates a single batch of synthetic data"""
     output = {}
     output["inputs"] = jax.numpy.zeros((config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32)
     output["inputs_position"] = jax.numpy.zeros(
diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index b1c0803d7..227267bc7 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -853,7 +853,7 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
       raise ValueError("num_kv_heads is not defined.")
 
     if self.num_query_heads % self.num_kv_heads != 0:
-      raise ValueError("Invaid num_kv_heads for GQA.")
+      raise ValueError("Invalid num_kv_heads for GQA.")
 
     kv_proj = DenseGeneral(
         features=(self.num_kv_heads, self.head_dim),
@@ -918,7 +918,7 @@ def __call__(
     Projects the inputs into multi-headed query, key, and value vectors,
     applies dot-product attention and project the results to an output vector.
 
-    There are three modes: training, prefill and autoregression. During training, the KV cahce
+    There are three modes: training, prefill and autoregression. During training, the KV cache
     is ignored. During prefill, the cache is filled. During autoregression the cache is used.
 
     In the cache initialization call, `inputs_q` has a shape [batch, length,
diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py
index 3d3f35b9b..b5bf67cef 100644
--- a/MaxText/layers/linears.py
+++ b/MaxText/layers/linears.py
@@ -120,7 +120,7 @@ def compute_dot_general(inputs, kernel, axis, contract_ind):
     kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
     if quantizations.in_serve_mode(self.quant):
       # During aqt convert state we delete kernel weight from params to save memory.
-      # Instead they are retreived from the tensors stored in the 'aqt' collection.
+      # Instead they are retrieved from the tensors stored in the 'aqt' collection.
       kernel = jnp.zeros(kernel_shape)
     else:
       kernel = self.param(
diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py
index d4de84b41..cce01bfee 100644
--- a/MaxText/max_utils.py
+++ b/MaxText/max_utils.py
@@ -526,7 +526,7 @@ def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, z_loss:
    logits: [batch, length, num_classes] float array.
    targets: categorical one-hot targets [batch, length, num_classes] float array.
 
-   z_loss: coefficient for auxilliary z-loss loss term.
+   z_loss: coefficient for auxiliary z-loss loss term.
 
  Returns:
    tuple with the total loss and the z_loss, both float arrays with shape [batch, length].
@@ -534,7 +534,7 @@ def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, z_loss:
   logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
   log_softmax = logits - logits_sum
   loss = -jnp.sum(targets * log_softmax, axis=-1)
-  # Add auxilliary z-loss term.
+  # Add auxiliary z-loss term.
   log_z = jnp.squeeze(logits_sum, axis=-1)
   total_z_loss = z_loss * jax.lax.square(log_z)
   loss += total_z_loss
@@ -554,7 +554,7 @@ def _cross_entropy_with_logits_fwd(
   sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True)
   log_softmax = shifted - jnp.log(sum_exp)
   loss = -jnp.sum(targets * log_softmax, axis=-1)
-  # Add auxilliary z-loss term.
+  # Add auxiliary z-loss term.
   log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1)
   total_z_loss = z_loss * jax.lax.square(log_z)
   loss += total_z_loss
diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py
index 597f696f4..3a2b84d13 100644
--- a/MaxText/maxtext_utils.py
+++ b/MaxText/maxtext_utils.py
@@ -169,7 +169,7 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01):
   perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding
   assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, (
       "Number of parameters per chip must not be less than in the ideal sharded "
-      "scenario accross `fsdp`, `fsdp_transpose`,`sequence`, `tensor` axes."
+      "scenario across `fsdp`, `fsdp_transpose`,`sequence`, `tensor` axes."
   )
   assert total_num_params_per_chip / perfectly_sharded_params_per_chip - 1 < tolerance, (
       f"Number of unsharded parameters exceeds tolerance {tolerance * 100}% " "of total parameters."
diff --git a/MaxText/optimizers.py b/MaxText/optimizers.py
index 63fcc42b1..1d15b358f 100644
--- a/MaxText/optimizers.py
+++ b/MaxText/optimizers.py
@@ -59,7 +59,7 @@ def adam_pax(
 ) -> optax.GradientTransformation:
   """Standard Adam optimizer that supports weight decay.
 
-  Follows the implemenation in pax/praxis sharded_adam
+  Follows the implementation in pax/praxis sharded_adam
   https://github.com/google/praxis/blob/545e00ab126b823265d70c715950d39333484f38/praxis/optimizers.py#L621
 
   Args:
diff --git a/MaxText/tests/gpt3_test.py b/MaxText/tests/gpt3_test.py
index 7dc4246d7..b1f0bed52 100644
--- a/MaxText/tests/gpt3_test.py
+++ b/MaxText/tests/gpt3_test.py
@@ -35,7 +35,7 @@
 
 
 def init_random_model_vars(model, rng, example_batch):
-  """initialze random model vars."""
+  """initialize random model vars."""
   model_vars = model.init(
       {"params": rng, "aqt": rng},
       example_batch["inputs"],
@@ -90,7 +90,7 @@ def test_logits_numerically(self):
     # ground truth values are calculated from paxml after loading above model_vars
     # note we expect all xents are the same except the padding one since:
     #   paxml applies padding in mlp layer
-    #   while maxtext implementaiton applies padding in attention mask instead
+    #   while maxtext implementation applies padding in attention mask instead
     #   the two implementation are equivalent in valid non-padding tokens
     per_example_xent_truth = jnp.array([[31.976467, 25.806253, 17.311134, 45.362663, 0.0]], dtype=jnp.float32)
     logits, _ = self.model.apply(
diff --git a/MaxText/tests/llama_test.py b/MaxText/tests/llama_test.py
index 6d7b7827c..8b60d4bb2 100644
--- a/MaxText/tests/llama_test.py
+++ b/MaxText/tests/llama_test.py
@@ -48,7 +48,7 @@ def apply_rotary_emb(
     freqs_cis: jnp.ndarray,
     dtype: jnp.dtype = jnp.bfloat16,
 ) -> Tuple[jnp.ndarray, jnp.ndarray]:
-  """Apply the computed Rotary Postional Embedding"""
+  """Apply the computed Rotary Positional Embedding"""
   reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
   reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
 
diff --git a/MaxText/train.py b/MaxText/train.py
index 9baca72ec..3c1b3ac07 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -74,7 +74,7 @@ def validate_train_config(config):
     max_logging.log("WARNING: 'dataset_path' might be pointing your local file system")
   if not config.base_output_directory.startswith("gs://"):
     max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system")
-  assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive interger."
+  assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer."
 
 
 def get_first_step(state):
diff --git a/MaxText/train_compile.py b/MaxText/train_compile.py
index 43789ea3f..ee8551031 100644
--- a/MaxText/train_compile.py
+++ b/MaxText/train_compile.py
@@ -69,7 +69,7 @@ def get_topology_mesh(config):
 
 def get_shaped_inputs(topology_mesh, config):
   """Get shaped abstractions of inputs to train_step: state, batch and rng"""
-  # Construct the model and optimizier to get shaped versions of the state
+  # Construct the model and optimizer to get shaped versions of the state
   quant = quantizations.configure_quantization(config)
   model = Transformer(config, topology_mesh, quant=quant)
   # The learning_rate_schedule is baked into the compiled object.
diff --git a/MaxText/vertex_tensorboard.py b/MaxText/vertex_tensorboard.py
index 9a106c32b..35c8ecc5e 100644
--- a/MaxText/vertex_tensorboard.py
+++ b/MaxText/vertex_tensorboard.py
@@ -72,7 +72,7 @@ def setup(self):
     return tensorboard_url
 
   def upload_data(self, tensorboard_dir):
-    """Starts an uploader to continously monitor and upload data to Vertex Tensorboard.
+    """Starts an uploader to continuously monitor and upload data to Vertex Tensorboard.
 
     Args:
       tensorboard_dir: directory that contains Tensorboard data.
diff --git a/constraints_gpu.txt b/constraints_gpu.txt
index e66ad56f6..e94f9c18e 100644
--- a/constraints_gpu.txt
+++ b/constraints_gpu.txt
@@ -96,6 +96,7 @@ pandas==2.2.1
 platformdirs==4.2.0
 pluggy==1.4.0
 portpicker==1.6.0
+pre-commit==3.7.0
 promise==2.3
 proto-plus==1.23.0
 protobuf==3.20.3
@@ -107,6 +108,7 @@ pydantic==1.10.14
 pydot==2.0.0
 pyglove==0.4.4
 Pygments==2.17.2
+pyink==24.3.0
 pylint==3.1.0
 pyparsing==3.1.2
 pytest==8.1.1
diff --git a/end_to_end/tpu/gemma/2b/test_gemma.sh b/end_to_end/tpu/gemma/2b/test_gemma.sh
index 74d776951..7ba09c1d9 100644
--- a/end_to_end/tpu/gemma/2b/test_gemma.sh
+++ b/end_to_end/tpu/gemma/2b/test_gemma.sh
@@ -15,7 +15,7 @@ idx=$(date +%Y-%m-%d-%H-%M)
 export MODEL_VARIATION='2b'
 
 # After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \
-# Non-Googlers please remember to use seperate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
+# Non-Googlers please remember to use separate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
 # Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing.
 export CHKPT_BUCKET=gs://maxtext-gemma/flax
 export MODEL_BUCKET=gs://maxtext-gemma
diff --git a/end_to_end/tpu/gemma/7b/1_test_gemma.sh b/end_to_end/tpu/gemma/7b/1_test_gemma.sh
index 521dd6550..f5db1a636 100644
--- a/end_to_end/tpu/gemma/7b/1_test_gemma.sh
+++ b/end_to_end/tpu/gemma/7b/1_test_gemma.sh
@@ -17,7 +17,7 @@
 MODEL_VARIATION='7b'
 
 # After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \
-# Please use seperate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($BASE_OUTPUT_PATH).
+# Please use separate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($BASE_OUTPUT_PATH).
 # Non-Googlers please remember to point CHKPT_BUCKET to GCS buckets that you own
 export CHKPT_BUCKET=gs://maxtext-gemma/flax
 
diff --git a/getting_started/Run_MaxText_via_multihost_runner.md b/getting_started/Run_MaxText_via_multihost_runner.md
index 14b215f8e..31a04e8e1 100644
--- a/getting_started/Run_MaxText_via_multihost_runner.md
+++ b/getting_started/Run_MaxText_via_multihost_runner.md
@@ -54,7 +54,7 @@ either be a TPUVM or not, but it cannot be one of the workers. If your runner ma
     Choose names for your TPUs and QR:
     ```
     TPU_PREFIX=$YOUR_TPU_NAME # Use new names when you create new TPUs
-    QR_ID=$TPU_PREFIX # Convenient to re-use the node names, but can be different
+    QR_ID=$TPU_PREFIX # Convenient to reuse the node names, but can be different
     ```
     Choose the number of nodes (we use 2 below, but you may customize this and other feature of your TPU(s))
     ```
diff --git a/maxtext_gpu_dependencies.Dockerfile b/maxtext_gpu_dependencies.Dockerfile
index 389dd234e..47cd646eb 100644
--- a/maxtext_gpu_dependencies.Dockerfile
+++ b/maxtext_gpu_dependencies.Dockerfile
@@ -38,13 +38,12 @@ RUN mkdir -p /deps
 # Set the working directory in the container
 WORKDIR /deps
 
-# Copy necessary build files to docker container
-COPY setup.sh requirements.txt constraints_gpu.txt /deps/
+# Copy all files from local workspace into docker container
+COPY . .
 RUN ls .
 
 RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
 RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}
 
-COPY . .
 WORKDIR /deps
 
diff --git a/requirements.txt b/requirements.txt
index ebe9a2e4f..c0bc46b26 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,6 +16,8 @@ optax
 protobuf==3.20.3
 pylint
 pytest
+pyink
+pre-commit
 pytype
 sentencepiece==0.1.97
 tensorflow-text>=2.13.0
diff --git a/setup.sh b/setup.sh
index 0388fac83..7e1a5f152 100644
--- a/setup.sh
+++ b/setup.sh
@@ -195,7 +195,8 @@ fi
 # Install dependencies from requirements.txt
 cd $run_name_folder_path && pip install --upgrade pip
 if [[ "$MODE" == "pinned" ]]; then
-    pip3 install -r requirements.txt -c constraints_gpu.txt
+    pip3 install -U -r requirements.txt -c constraints_gpu.txt
 else
     pip3 install -U -r requirements.txt
 fi
+[ -d ".git" ] && pre-commit install
\ No newline at end of file
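
Usage note (not part of the patch to apply): a minimal sketch of how the new hook is exercised locally, assuming pre-commit 3.x is available on PATH. setup.sh above already runs `pre-commit install` when a .git directory is present; the manual equivalent is:

    pip3 install pre-commit               # also pulled in via requirements.txt
    pre-commit install                    # register the git hook once per clone
    pre-commit run --all-files            # run every configured hook across the repo
    pre-commit run codespell --all-files  # run only the codespell hook

Because the hook entry passes -w (--write-changes), codespell rewrites recognized typos in place rather than only reporting them, so a failed hook leaves the fixes staged for review.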