Skip to content

Commit b7daa82

Browse files
author
maxtext authors
committed
Merge pull request #328 from google:rwitten_add_env_vars
PiperOrigin-RevId: 598020582
2 parents 010fb68 + f6d0868 commit b7daa82

File tree

4 files changed

+41
-16
lines changed

4 files changed

+41
-16
lines changed

MaxText/max_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ
153153

154154
return parallelism_vals
155155

156-
def create_device_mesh(config, devices=None, logging=True):
156+
def create_device_mesh(config, devices=None):
157157
"""Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas """
158158
if devices is None:
159159
devices = jax.devices()
@@ -163,7 +163,6 @@ def create_device_mesh(config, devices=None, logging=True):
163163
except:
164164
num_slices = 1
165165
num_devices_per_slice = num_devices//num_slices
166-
max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")
167166

168167
multi_slice_env = num_slices > 1
169168

@@ -183,8 +182,7 @@ def create_device_mesh(config, devices=None, logging=True):
183182
else:
184183
mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
185184

186-
if logging:
187-
max_logging.log(f"Decided on mesh: {mesh}")
185+
max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
188186

189187
return mesh
190188

MaxText/pyconfig.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030

3131
from typing import Any, Union
3232

33+
_MAX_PREFIX = "M_"
34+
def yaml_key_to_env_key(s: str) -> str:
35+
return _MAX_PREFIX + s.upper()
36+
3337
def string_to_bool(s: str) -> bool:
3438
if s.lower() == "true":
3539
return True
@@ -61,9 +65,17 @@ def _lists_to_tuples(l: list[Any]) -> Union[tuple[Any],list[Any]]:
6165

6266
class _HyperParameters():
6367
# pylint: disable=missing-class-docstring
68+
def _validate_env_variables(self, raw_data_from_yaml):
69+
for environment_var in os.environ:
70+
if environment_var[:len(_MAX_PREFIX)] == _MAX_PREFIX:
71+
proposed_key = environment_var[len(_MAX_PREFIX):].lower()
72+
if proposed_key not in raw_data_from_yaml:
73+
raise ValueError(f"We received env {environment_var} but it doesn't match a key, so it is aassumed a mistake")
74+
6475
def __init__(self, argv: list[str], **kwargs):
6576
with open(argv[1], "r", encoding="utf-8") as yaml_file:
6677
raw_data_from_yaml = yaml.safe_load(yaml_file)
78+
self._validate_env_variables(raw_data_from_yaml)
6779
raw_data_from_cmd_line = self._load_kwargs(argv, **kwargs)
6880

6981
for k in raw_data_from_cmd_line:
@@ -74,28 +86,43 @@ def __init__(self, argv: list[str], **kwargs):
7486

7587
raw_keys = OrderedDict()
7688
for k in raw_data_from_yaml:
77-
if k in raw_data_from_cmd_line and not isinstance(raw_data_from_cmd_line[k], type(raw_data_from_yaml[k])) and \
78-
type(raw_data_from_yaml[k]) not in _yaml_types_to_parser:
89+
if k in raw_data_from_cmd_line and yaml_key_to_env_key(k) in os.environ:
90+
raise ValueError(f"You are passing overrides by both CLI and ENV for `{k}`. This isn't allowed.")
91+
92+
if not k in raw_data_from_cmd_line and not yaml_key_to_env_key(k) in os.environ:
93+
raw_keys[k] = raw_data_from_yaml[k]
94+
continue
95+
96+
if k in raw_data_from_cmd_line:
97+
new_proposal = raw_data_from_cmd_line[k]
98+
else:
99+
new_proposal = os.environ.get(yaml_key_to_env_key(k))
100+
101+
if (not isinstance(new_proposal, type(raw_data_from_yaml[k]))) and \
102+
(type(raw_data_from_yaml[k]) not in _yaml_types_to_parser):
79103
raise ValueError(
80104
f"For key '{k}', type {type(raw_data_from_yaml[k])} not in {_yaml_types_to_parser.keys()}, can't pass"
81-
" at the command line"
105+
" at the CLI or ENV"
82106
)
83107

84-
if k in raw_data_from_cmd_line and isinstance(raw_data_from_cmd_line[k], type(raw_data_from_yaml[k])):
85-
raw_keys[k] = raw_data_from_cmd_line[k] # take the raw data, no type conversion
86-
elif k in raw_data_from_cmd_line:
108+
if isinstance(new_proposal, type(raw_data_from_yaml[k])):
109+
raw_keys[k] = new_proposal # take the raw data, no type conversion
110+
else:
87111
try:
88112
raw_keys[k] = _yaml_types_to_parser[type(raw_data_from_yaml[k])](
89-
raw_data_from_cmd_line[k]
113+
new_proposal
90114
) # take the command line value, but type it like the config value.
91115
except ValueError as e:
92-
raise ValueError(f"Couldn't parse value from command line '{raw_data_from_cmd_line[k]}' for key '{k}'") from e
93-
else:
94-
raw_keys[k] = raw_data_from_yaml[k]
116+
raise ValueError(f"Couldn't parse value from CLI or ENV '{new_proposal}' for key '{k}'") from e
95117

96118
_HyperParameters.update_model_vars(raw_keys)
97119
_HyperParameters.user_init(raw_keys)
120+
98121
self.keys = raw_keys
122+
keys = [k for k in raw_keys] # pylint: disable=unnecessary-comprehension
123+
keys.sort()
124+
for k in keys:
125+
max_logging.log(f"Config param {k}: {raw_keys[k]}")
99126

100127
def _load_kwargs(self, argv: list[str], **kwargs):
101128
args_dict = dict(a.split("=") for a in argv[2:])

MaxText/train.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,6 @@ def main(argv: Sequence[str]) -> None:
345345
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
346346
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
347347
pyconfig.initialize(argv)
348-
print(f"Found {jax.device_count()} devices.")
349348
config = pyconfig.config
350349
validate_train_config(config)
351350
cc.initialize_cache(os.path.expanduser(config.jax_cache_dir))

end_to_end/llama_load_and_test.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
set -e
22
idx=$(date +%Y-%m-%d-%H-%M)
3+
34
#TODO(internal bug -- migrate to XLML)
45
#pip install torch
56
#gsutil cp -r gs://maxtext-llama/llama2-7b/meta-ckpt /tmp/
67
#python3 MaxText/convert_llama_ckpt.py --base-model-path /tmp/meta-ckpt --model-size 7b --maxtext-model-path gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext/
78
#python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_${idx} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset load_from_other_directory=gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext/ load_from_other_directory_step=0 per_device_batch_size=1 model_name='llama2-7b' assets_path=gs://maxtext-llama/llama2-7b ici_tensor_parallelism=4 steps=1 max_prefill_predict_length=4 max_target_length=16 async_checkpointing=false prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product
89

9-
python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_${idx} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset load_from_other_directory=gs://maxtext-llama/test/2024-01-12-17-46/decode-ckpt-maxtext/ load_from_other_directory_step=0 per_device_batch_size=1 model_name='llama2-7b' assets_path=gs://maxtext-llama/llama2-7b ici_tensor_parallelism=4 steps=1 max_prefill_predict_length=4 max_target_length=16 async_checkpointing=false prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product
10+
python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_${idx} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset load_from_other_directory=gs://maxtext-llama/test/2024-01-12-17-46/decode-ckpt-maxtext/ load_from_other_directory_step=0 per_device_batch_size=1 model_name='llama2-7b' assets_path=gs://maxtext-llama/llama2-7b ici_tensor_parallelism=4 steps=1 max_prefill_predict_length=4 max_target_length=16 async_checkpointing=false prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product

0 commit comments

Comments
 (0)