diff --git a/tensor2tensor/utils/trainer_lib.py b/tensor2tensor/utils/trainer_lib.py index cfd516c4f..8f758fa0c 100644 --- a/tensor2tensor/utils/trainer_lib.py +++ b/tensor2tensor/utils/trainer_lib.py @@ -178,7 +178,8 @@ def create_run_config(model_name, log_step_count_steps=100, intra_op_parallelism_threads=0, tpu_config_extra_kwargs=None, - cloud_tpu_name=""): + cloud_tpu_name="", + cloud_tpu_zone=None): """Create RunConfig, TPUConfig, and Parallelism object.""" session_config = create_session_config( log_device_placement=log_device_placement, @@ -229,7 +230,7 @@ def create_run_config(model_name, # Update run_config to use cluster instead of master/evaluation_master # as we need the cluster spec to use Cloud Pods tpu_cluster_resolver = contrib.cluster_resolver().TPUClusterResolver( - cloud_tpu_name) + tpu=cloud_tpu_name, zone=cloud_tpu_zone) run_config_args["cluster"] = tpu_cluster_resolver del run_config_args["master"] del run_config_args["evaluation_master"]