diff --git a/xlml/utils/xpk.py b/xlml/utils/xpk.py index d3d8e4828..b2b17eea4 100644 --- a/xlml/utils/xpk.py +++ b/xlml/utils/xpk.py @@ -59,8 +59,6 @@ def run_workload( ): """Run workload through xpk tool.""" - run_cmds = f"export {metric_config.SshEnvVars.GCS_OUTPUT.name}={gcs_path}; {run_cmds}" - with tempfile.TemporaryDirectory() as tmpdir: cmds = ( "set -xue", @@ -71,6 +69,7 @@ def run_workload( f" --command='{run_cmds}' --device-type={accelerator_type}" f" --num-slices={num_slices} --docker-image={docker_image}" f" --project={cluster_project} --zone={zone}" + f" --env {metric_config.SshEnvVars.GCS_OUTPUT.name}={gcs_path}" ), ) hook = SubprocessHook() @@ -142,14 +141,16 @@ def wait_for_workload_completion( elif pod.status.phase in ["Unknown"]: raise RuntimeError(f"Bad pod phase: {pod.status.phase}") finally: - # Print the logs of the last pod checked - either the first failed pod or - # the last successful one. - logs = core_api.read_namespaced_pod_log( - name=pod.metadata.name, namespace=pod.metadata.namespace - ) - logging.info(f"Logs for pod {pod.metadata.name}:") - for line in logs.split("\n"): - logging.info(line) + # TODO(jonbolin): log printing for GPUs, which have multiple containers + if len(pod.spec.containers) == 1: + # Print the logs of the last pod checked - either the first failed pod or + # the last successful one. + logs = core_api.read_namespaced_pod_log( + name=pod.metadata.name, namespace=pod.metadata.namespace + ) + logging.info(f"Logs for pod {pod.metadata.name}:") + for line in logs.split("\n"): + logging.info(line) url = WORKLOAD_URL_FORMAT.format( region=region, cluster=cluster_name,