From 62a43c793464c3723fd2aa6e2a90217bc6955bb0 Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Fri, 27 Oct 2023 17:57:34 +0000 Subject: [PATCH] getting there with args --- run_xpk_sweeps.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/run_xpk_sweeps.py b/run_xpk_sweeps.py index 68bd01eea..11ff9f063 100644 --- a/run_xpk_sweeps.py +++ b/run_xpk_sweeps.py @@ -9,8 +9,11 @@ args = { 'dryrun': True, - 'tpu': 'v4', # 'v4' 'v5' - 'stable': False, + 'cluster': False, + 'docker_image': '', + 'tpu-type': 'v4', # 'v4' 'v5litepod-256' + + } @@ -24,9 +27,7 @@ def update_yaml_fields(yaml_data, update_dict, allow_new_keys=False): BASE_CMD="""export LIBTPU_INIT_ARGS="--xla_tpu_spmd_rng_bit_generator_unsafe=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" && \ -bash setup_with_retries.sh && \ -bash rto_setup.sh && \ -python3 MaxText/train.py """ +python3 MaxText/train.py MaxText/configs/base.yml""" def bname(b: bool): assert b == True or b == False, f'not bool: "{b}"' @@ -34,7 +35,7 @@ def bname(b: bool): def run_job(run_name, base_config, **config_updates): maxtext_config = update_yaml_fields(base_config, config_updates) - model_size = maxtext_config['global_parameter_scale'] + # model_size = maxtext_config['global_parameter_scale'] with open('MaxText/configs/base.yml', 'r') as file: base_yml = yaml.safe_load(file) @@ -60,10 +61,7 @@ def calc_chinchilla_step_count(num_params_billions, num_slice, seqs_per_chip, to attempt = args['attempt'] sweep_name = args['sweep'] - use_cl = args['jax_14_cl'] - assert use_cl, 'forbidden to not use it' run_name = f'int8-{sweep_name}-a{attempt}-{run_name}' - jobre = args['jobre'] url = f"https://pantheon.corp.google.com/logs/query;query=timestamp%20%3E%20%222023-08-18%22%20AND%20labels.%22agent.googleapis.com%2Flog_file_path%22%3D~%22{run_name}.*%2Fmain_command_log_slice_0_worker_0%22" if not re.findall(jobre, run_name): @@ -77,10 +75,7 @@ def calc_chinchilla_step_count(num_params_billions, num_slice, seqs_per_chip, to with open(experiment_yml_file, 'w') as file: yaml.dump(yml, file) - if args['jax_14_cl']: - mhj_cmd = BASE_MHJ_CMD_14_CP - else: - mhj_cmd = BASE_MHJ_CMD + xpk_cmd = BASE_XPK_CMD experiment_mhj = { '--RUN_NAME': run_name,