Skip to content

Commit

Permalink
Add support for installing nightly version of`tensorboard_plugin_prof…
Browse files Browse the repository at this point in the history
…ile` (#194)

* initial commit

* fix pylint bug

* uninstall tbp-nightly ahead of time

* stable tbp

* use upgrade flag

* fix comment
  • Loading branch information
shahyash10 authored Oct 4, 2023
1 parent 7c9150f commit 8a4b417
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
3 changes: 1 addition & 2 deletions MaxText/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

import dataclasses
import functools
import operator
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -656,14 +655,14 @@ class Embed(nn.Module):
dtype: the dtype of the embedding vectors (default: float32).
embedding_init: embedding initializer.
"""
# pylint: disable=attribute-defined-outside-init
config: Config
num_embeddings: int
features: int
cast_input_dtype: Optional[DType] = None
dtype: DType = jnp.float32
attend_dtype: Optional[DType] = None
embedding_init: Initializer = default_embed_init
embedding: Array = dataclasses.field(init=False)

def setup(self):
self.embedding = self.param(
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,5 @@ pytest
sentencepiece==0.1.97
tensorflow==2.13.*
tensorflow-datasets
tensorboard-plugin-profile
tensorflow-text
tensorboardx
8 changes: 7 additions & 1 deletion setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ fi
# Save the script folder path of maxtext
run_name_folder_path=$(pwd)

# Uninstall existing jax, jaxlib, and libtpu-nightly
# Uninstall existing jax, jaxlib and libtpu-nightly
pip3 show jax && pip3 uninstall -y jax
pip3 show jaxlib && pip3 uninstall -y jaxlib
pip3 show libtpu-nightly && pip3 uninstall -y libtpu-nightly
Expand Down Expand Up @@ -104,6 +104,8 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
# Copy libtpu.so from GCS path
gsutil cp "$LIBTPU_GCS_PATH" "$libtpu_path"
fi
echo "Installing stable tensorboard plugin profile"
pip3 install tensorboard-plugin-profile --upgrade
elif [[ $MODE == "nightly" ]]; then
# Nightly mode
echo "Installing jax-head, jaxlib-nightly"
Expand All @@ -124,6 +126,8 @@ elif [[ $MODE == "nightly" ]]; then
echo "Installing libtpu-nightly"
pip3 install libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -U --pre
fi
echo "Installing nightly tensorboard plugin profile"
pip3 install tbp-nightly --upgrade
elif [[ $MODE == "head" ]]; then
# Head mode
if [[ -n "$LIBTPU_GCS_PATH" ]]; then
Expand All @@ -150,6 +154,8 @@ elif [[ $MODE == "head" ]]; then
cd $HOME/jax
python3 build/build.py --enable_tpu --bazel_options="--override_repository=xla=$HOME/xla"
pip3 install dist/jaxlib-*-cp*-manylinux2014_x86_64.whl --force-reinstall --no-deps
echo "Installing nightly tensorboard plugin profile"
pip3 install tbp-nightly --upgrade
else
echo -e "\n\nError: You can only set MODE to [stable,nightly,head,libtpu-only].\n\n"
exit 1
Expand Down

0 comments on commit 8a4b417

Please sign in to comment.