Skip to content

Commit

Permalink
Add conv tests for BART and BERT (#68)
Browse files Browse the repository at this point in the history
* Add conv tests for BART and BERT

* Update timeout for bert_mnli

* Fix issues during merge
  • Loading branch information
RissyRan authored Jan 11, 2024
1 parent a27051d commit 24c2ff9
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 47 deletions.
191 changes: 163 additions & 28 deletions configs/xlml/jax/solutionsteam_flax_latest_supported_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
from configs import gcs_bucket, test_owner
from configs.xlml.jax import common
from configs.vm_resource import TpuVersion, Project, RuntimeVersion
from datetime import datetime


PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value
RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value
RUN_DATE = datetime.now().strftime("%Y_%m_%d")


def get_flax_resnet_config(
Expand Down Expand Up @@ -175,11 +177,9 @@ def get_flax_vit_conv_config(

set_up_cmds = get_flax_vit_setup_cmds()
tf_summary_location = (
"/tmp/transformers/vit-imagenette/events.out.tfevents.jax-vit.v2"
)
gcs_location = (
f"{gcs_bucket.XLML_OUTPUT_DIR}/flax/vit/metric/events.out.tfevents.jax-vit.v2"
"/tmp/transformers/vit-imagenette/events.out.tfevents.flax-vit.v2"
)
gcs_location = f"{gcs_bucket.XLML_OUTPUT_DIR}/flax/vit/{RUN_DATE}/events.out.tfevents.flax-vit.v2"
extra_run_cmds = (
(
"cp /tmp/transformers/vit-imagenette/events.out.tfevents.*"
Expand Down Expand Up @@ -351,38 +351,50 @@ def get_flax_sd_config(
)


def get_flax_bart_config(
tpu_version: TpuVersion,
tpu_cores: int,
tpu_zone: str,
time_out_in_min: int,
extraFlags: str = "",
is_tpu_reserved: bool = True,
) -> task.TpuQueuedResourceTask:
job_gcp_config = gcp_config.GCPConfig(
project_name=PROJECT_NAME,
zone=tpu_zone,
dataset_name=metric_config.DatasetOption.XLML_DATASET,
)

set_up_cmds = common.set_up_hugging_face_transformers() + (
def get_flax_bart_setup_cmds() -> Tuple[str]:
return common.set_up_hugging_face_transformers() + (
"pip install -r examples/flax/summarization/requirements.txt",
"pip install ml_dtypes==0.2.0",
)

run_model_cmds = (

def get_flax_bart_run_model_cmds(
num_train_epochs: int,
extraFlags: str = "",
extra_run_cmds: Tuple[str] = ("echo",),
) -> Tuple[str]:
return (
(
"cd /tmp/transformers/examples/flax/summarization &&"
" JAX_PLATFORM_NAME=TPU python3 run_summarization_flax.py"
" --model_name_or_path facebook/bart-base --tokenizer_name"
" facebook/bart-base --dataset_name wiki_summary --do_train"
" --do_eval --do_predict --predict_with_generate --learning_rate"
" 5e-5 --warmup_steps 0 --output_dir=./bart-base-wiki"
" --overwrite_output_dir --num_train_epochs 3 --max_source_length"
" 5e-5 --warmup_steps 0 --output_dir=/tmp/transformers/bart-base-wiki"
f" --overwrite_output_dir --num_train_epochs {num_train_epochs} --max_source_length"
f" 512 --max_target_length 64 {extraFlags}"
),
) + extra_run_cmds


def get_flax_bart_config(
tpu_version: TpuVersion,
tpu_cores: int,
tpu_zone: str,
time_out_in_min: int,
num_train_epochs: int,
is_tpu_reserved: bool = True,
extraFlags: str = "",
) -> task.TpuQueuedResourceTask:
job_gcp_config = gcp_config.GCPConfig(
project_name=PROJECT_NAME,
zone=tpu_zone,
dataset_name=metric_config.DatasetOption.XLML_DATASET,
)

set_up_cmds = get_flax_bart_setup_cmds()
run_model_cmds = get_flax_bart_run_model_cmds(num_train_epochs, extraFlags)

job_test_config = test_config.TpuVmTest(
test_config.Tpu(
version=tpu_version,
Expand All @@ -403,13 +415,12 @@ def get_flax_bart_config(
)


def get_flax_bert_config(
def get_flax_bart_conv_config(
tpu_version: TpuVersion,
tpu_cores: int,
tpu_zone: str,
time_out_in_min: int,
task_name: str,
num_train_epochs: int = 1,
num_train_epochs: int,
extraFlags: str = "",
is_tpu_reserved: bool = True,
) -> task.TpuQueuedResourceTask:
Expand All @@ -419,21 +430,89 @@ def get_flax_bert_config(
dataset_name=metric_config.DatasetOption.XLML_DATASET,
)

set_up_cmds = common.set_up_hugging_face_transformers() + (
set_up_cmds = get_flax_bart_setup_cmds()
work_dir = "/tmp/transformers/bart-base-wiki"
tf_summary_location = f"{work_dir}/events.out.tfevents.flax-bart.v2"
gcs_location = f"{gcs_bucket.XLML_OUTPUT_DIR}/flax/bart/{RUN_DATE}/events.out.tfevents.flax-bart.v2"
extra_run_cmds = (
f"cp {work_dir}/events.out.tfevents.* {tf_summary_location} || exit 0",
f"gsutil cp {tf_summary_location} {gcs_location} || exit 0",
)
run_model_cmds = get_flax_bart_run_model_cmds(
num_train_epochs, extraFlags, extra_run_cmds
)

job_test_config = test_config.TpuVmTest(
test_config.Tpu(
version=tpu_version,
cores=tpu_cores,
runtime_version=RUNTIME_IMAGE,
reserved=is_tpu_reserved,
),
test_name="flax_bart_wiki_conv",
set_up_cmds=set_up_cmds,
run_model_cmds=run_model_cmds,
time_out_in_min=time_out_in_min,
task_owner=test_owner.SHIVA_S,
)

job_metric_config = metric_config.MetricConfig(
tensorboard_summary=metric_config.SummaryConfig(
file_location=gcs_location,
aggregation_strategy=metric_config.AggregationStrategy.LAST,
)
)

return task.TpuQueuedResourceTask(
task_test_config=job_test_config,
task_gcp_config=job_gcp_config,
task_metric_config=job_metric_config,
)


def get_flax_bert_setup_cmds() -> Tuple[str]:
return common.set_up_hugging_face_transformers() + (
"pip install -r examples/flax/text-classification/requirements.txt",
"pip install ml_dtypes==0.2.0",
)

run_model_cmds = (

def get_flax_bert_run_model_cmds(
task_name: str,
num_train_epochs: int,
extraFlags: str = "",
extra_run_cmds: Tuple[str] = ("echo",),
) -> Tuple[str]:
return (
(
"cd /tmp/transformers/examples/flax/text-classification &&"
" JAX_PLATFORM_NAME=TPU python3 run_flax_glue.py --output_dir"
" ./bert-glue --model_name_or_path bert-base-cased"
" /tmp/transformers/bert-glue --model_name_or_path bert-base-cased"
f" --overwrite_output_dir --task_name {task_name} --num_train_epochs"
f" {num_train_epochs} --logging_dir ./bert-glue {extraFlags}"
),
) + extra_run_cmds


def get_flax_bert_config(
tpu_version: TpuVersion,
tpu_cores: int,
tpu_zone: str,
time_out_in_min: int,
task_name: str,
num_train_epochs: int = 1,
is_tpu_reserved: bool = True,
extraFlags: str = "",
) -> task.TpuQueuedResourceTask:
job_gcp_config = gcp_config.GCPConfig(
project_name=PROJECT_NAME,
zone=tpu_zone,
dataset_name=metric_config.DatasetOption.XLML_DATASET,
)

set_up_cmds = get_flax_bert_setup_cmds()
run_model_cmds = get_flax_bert_run_model_cmds(task_name, num_train_epochs, extraFlags)

job_test_config = test_config.TpuVmTest(
test_config.Tpu(
version=tpu_version,
Expand All @@ -454,6 +533,62 @@ def get_flax_bert_config(
)


def get_flax_bert_conv_config(
tpu_version: TpuVersion,
tpu_cores: int,
tpu_zone: str,
time_out_in_min: int,
task_name: str,
num_train_epochs: int = 1,
is_tpu_reserved: bool = True,
extraFlags: str = "",
) -> task.TpuQueuedResourceTask:
job_gcp_config = gcp_config.GCPConfig(
project_name=PROJECT_NAME,
zone=tpu_zone,
dataset_name=metric_config.DatasetOption.XLML_DATASET,
)

set_up_cmds = get_flax_bert_setup_cmds()
work_dir = "/tmp/transformers/bert-glue"
tf_summary_location = f"{work_dir}/events.out.tfevents.flax-bert.v2"
gcs_location = f"{gcs_bucket.XLML_OUTPUT_DIR}/flax/bert/{task_name}/{RUN_DATE}/events.out.tfevents.flax-bert.v2"
extra_run_cmds = (
f"cp {work_dir}/events.out.tfevents.* {tf_summary_location} || exit 0",
f"gsutil cp {tf_summary_location} {gcs_location} || exit 0",
)
run_model_cmds = get_flax_bert_run_model_cmds(
task_name, num_train_epochs, extraFlags, extra_run_cmds
)

job_test_config = test_config.TpuVmTest(
test_config.Tpu(
version=tpu_version,
cores=tpu_cores,
runtime_version=RUNTIME_IMAGE,
reserved=is_tpu_reserved,
),
test_name=f"flax_bert_{task_name}_conv",
set_up_cmds=set_up_cmds,
run_model_cmds=run_model_cmds,
time_out_in_min=time_out_in_min,
task_owner=test_owner.SHIVA_S,
)

job_metric_config = metric_config.MetricConfig(
tensorboard_summary=metric_config.SummaryConfig(
file_location=gcs_location,
aggregation_strategy=metric_config.AggregationStrategy.LAST,
)
)

return task.TpuQueuedResourceTask(
task_test_config=job_test_config,
task_gcp_config=job_gcp_config,
task_metric_config=job_metric_config,
)


def get_flax_wmt_config(
tpu_version: TpuVersion,
tpu_cores: int,
Expand Down
53 changes: 34 additions & 19 deletions dags/xlml/solutionsteam_flax_latest_supported.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
jax_vit_conv_extra_flags = jax_vit_func_extra_flags + [
"--model_name_or_path google/vit-base-patch16-224-in21k",
]
jax_vit_v4_32 = flax_config.get_flax_vit_conv_config(
jax_vit_conv_v4_32 = flax_config.get_flax_vit_conv_config(
tpu_version=TpuVersion.V4,
tpu_cores=32,
tpu_zone=Zone.US_CENTRAL2_B.value,
Expand Down Expand Up @@ -232,31 +232,43 @@
tpu_cores=8,
tpu_zone=Zone.US_CENTRAL2_B.value,
time_out_in_min=60,
num_train_epochs=3,
extraFlags=" ".join(jax_bart_v4_8_extra_flags),
).run()

jax_bart_v4_32_extra_flags = [
"--per_device_train_batch_size=32",
"--per_device_eval_batch_size=32",
]
jax_bart_v4_32 = flax_config.get_flax_bart_config(
jax_bart_conv_v4_32 = flax_config.get_flax_bart_conv_config(
tpu_version=TpuVersion.V4,
tpu_cores=32,
tpu_zone=Zone.US_CENTRAL2_B.value,
time_out_in_min=60,
num_train_epochs=30,
extraFlags=" ".join(jax_bart_v4_32_extra_flags),
is_tpu_reserved=False,
).run()

# BERT
jax_bert_v4_batch_size = [
"--per_device_train_batch_size=8",
"--per_device_eval_batch_size=8",
]
jax_bert_conv_extra_flags = [
"--learning_rate 2e-5",
"--eval_steps 500",
]

jax_bert_mnli_extra_flags = [
"--max_seq_length 512",
"--eval_steps 1000",
]
jax_bert_v4_mnli_extra_flags = jax_bert_mnli_extra_flags + [
"--per_device_train_batch_size=8",
"--per_device_eval_batch_size=8",
]
jax_bert_v4_mnli_extra_flags = jax_bert_mnli_extra_flags + jax_bert_v4_batch_size
jax_bert_v4_mnli_conv_extra_flags = (
jax_bert_mnli_extra_flags + jax_bert_v4_batch_size + jax_bert_conv_extra_flags
)

jax_bert_mnli_v4_8 = flax_config.get_flax_bert_config(
tpu_version=TpuVersion.V4,
tpu_cores=8,
Expand All @@ -266,24 +278,26 @@
extraFlags=" ".join(jax_bert_v4_mnli_extra_flags),
).run()

jax_bert_mnli_v4_32 = flax_config.get_flax_bert_config(
jax_bert_mnli_conv_v4_32 = flax_config.get_flax_bert_conv_config(
tpu_version=TpuVersion.V4,
tpu_cores=32,
tpu_zone=Zone.US_CENTRAL2_B.value,
time_out_in_min=60,
time_out_in_min=120,
task_name="mnli",
extraFlags=" ".join(jax_bert_v4_mnli_extra_flags),
num_train_epochs=3,
extraFlags=" ".join(jax_bert_v4_mnli_conv_extra_flags),
is_tpu_reserved=False,
).run()

jax_bert_mrpc_extra_flags = [
"--max_seq_length 128",
"--eval_steps 100",
]
jax_bert_v4_mrpc_extra_flags = jax_bert_mrpc_extra_flags + [
"--per_device_train_batch_size=8",
"--per_device_eval_batch_size=8",
]
jax_bert_v4_mrpc_extra_flags = jax_bert_mrpc_extra_flags + jax_bert_v4_batch_size
jax_bert_v4_mrpc_conv_extra_flags = (
jax_bert_mrpc_extra_flags + jax_bert_v4_batch_size + jax_bert_conv_extra_flags
)

jax_bert_mrpc_v4_8 = flax_config.get_flax_bert_config(
tpu_version=TpuVersion.V4,
tpu_cores=8,
Expand All @@ -293,13 +307,14 @@
extraFlags=" ".join(jax_bert_v4_mrpc_extra_flags),
).run()

jax_bert_mrpc_v4_32 = flax_config.get_flax_bert_config(
jax_bert_mrpc_conv_v4_32 = flax_config.get_flax_bert_conv_config(
tpu_version=TpuVersion.V4,
tpu_cores=32,
tpu_zone=Zone.US_CENTRAL2_B.value,
time_out_in_min=60,
task_name="mrpc",
extraFlags=" ".join(jax_bert_v4_mrpc_extra_flags),
num_train_epochs=3,
extraFlags=" ".join(jax_bert_v4_mrpc_conv_extra_flags),
is_tpu_reserved=False,
).run()

Expand All @@ -318,13 +333,13 @@
jax_resnet_v4_8 >> jax_resnet_v4_32
jax_resnet_v5e_4 >> jax_resnet_v5e_16
jax_resnet_v5p_8 >> jax_resnet_v5p_32
jax_vit_v4_8 >> jax_vit_v4_32
jax_vit_v4_8 >> jax_vit_conv_v4_32
jax_vit_v5e_4
jax_gpt2_v4_8 >> jax_gpt2_v4_32
jax_gpt2_v5e_4
jax_sd_v4_8 >> jax_sd_v4_32
jax_sd_v5e_4
jax_bart_v4_8 >> jax_bart_v4_32
jax_bert_mnli_v4_8 >> jax_bert_mnli_v4_32
jax_bert_mrpc_v4_8 >> jax_bert_mrpc_v4_32
jax_bart_v4_8 >> jax_bart_conv_v4_32
jax_bert_mnli_v4_8 >> jax_bert_mnli_conv_v4_32
jax_bert_mrpc_v4_8 >> jax_bert_mrpc_conv_v4_32
jax_wmt_v4_8

0 comments on commit 24c2ff9

Please sign in to comment.