Skip to content

Commit

Permalink
getting there with args
Browse files Browse the repository at this point in the history
  • Loading branch information
gobbleturk committed Oct 27, 2023
1 parent 993b7b9 commit 62a43c7
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions run_xpk_sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@

args = {
'dryrun': True,
'tpu': 'v4', # 'v4' 'v5'
'stable': False,
'cluster': False,
'docker_image': '',
'tpu-type': 'v4', # 'v4' 'v5litepod-256'


}


Expand All @@ -24,17 +27,15 @@ def update_yaml_fields(yaml_data, update_dict, allow_new_keys=False):


BASE_CMD="""export LIBTPU_INIT_ARGS="--xla_tpu_spmd_rng_bit_generator_unsafe=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" && \
bash setup_with_retries.sh && \
bash rto_setup.sh && \
python3 MaxText/train.py """
python3 MaxText/train.py MaxText/configs/base.yml"""

def bname(b: bool):
assert b == True or b == False, f'not bool: "{b}"'
return str(b)[0]

def run_job(run_name, base_config, **config_updates):
maxtext_config = update_yaml_fields(base_config, config_updates)
model_size = maxtext_config['global_parameter_scale']
# model_size = maxtext_config['global_parameter_scale']
with open('MaxText/configs/base.yml', 'r') as file:
base_yml = yaml.safe_load(file)

Expand All @@ -60,10 +61,7 @@ def calc_chinchilla_step_count(num_params_billions, num_slice, seqs_per_chip, to

attempt = args['attempt']
sweep_name = args['sweep']
use_cl = args['jax_14_cl']
assert use_cl, 'forbidden to not use it'
run_name = f'int8-{sweep_name}-a{attempt}-{run_name}'

jobre = args['jobre']
url = f"https://pantheon.corp.google.com/logs/query;query=timestamp%20%3E%20%222023-08-18%22%20AND%20labels.%22agent.googleapis.com%2Flog_file_path%22%3D~%22{run_name}.*%2Fmain_command_log_slice_0_worker_0%22"
if not re.findall(jobre, run_name):
Expand All @@ -77,10 +75,7 @@ def calc_chinchilla_step_count(num_params_billions, num_slice, seqs_per_chip, to
with open(experiment_yml_file, 'w') as file:
yaml.dump(yml, file)

if args['jax_14_cl']:
mhj_cmd = BASE_MHJ_CMD_14_CP
else:
mhj_cmd = BASE_MHJ_CMD
xpk_cmd = BASE_XPK_CMD

experiment_mhj = {
'--RUN_NAME': run_name,
Expand Down

0 comments on commit 62a43c7

Please sign in to comment.