Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
gunjanj007 committed Dec 6, 2024
1 parent fa6219a commit 443d369
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions dags/map_reproducibility/aotc_reproducibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import re
from google.cloud import storage


def set_variables_cmds():
set_variables = (
# "set -e",
Expand All @@ -38,6 +39,7 @@ def configure_project_and_cluster():
)
return set_project_command


def git_cookie_authdaemon():
auth_cmds = (
"git clone https://gerrit.googlesource.com/gcompute-tools",
Expand All @@ -63,6 +65,7 @@ def clone_gob():
)
return gob_clone_cmds


def stop_git_daemon():
cmd = (
"git config --global --unset credential.helper",
Expand Down Expand Up @@ -130,6 +133,7 @@ def copy_bucket_cmds():
)
return copy_bucket_contents


def get_metrics_cmds():
# TODO(gunjanj007): get these parameters from the recipe
get_metrics = (
Expand All @@ -145,6 +149,7 @@ def get_metrics_cmds():
)
return get_metrics


def get_aotc_repo():
gob_clone_cmds = (
"echo 'trying to clone GoB aotc repo'",
Expand All @@ -156,6 +161,7 @@ def get_aotc_repo():
)
return gob_clone_cmds


def cleanup_cmds():
cleanup = (
"cd $REPO_ROOT",
Expand All @@ -167,6 +173,7 @@ def cleanup_cmds():
)
return cleanup


def get_metrics_from_gcs(bucket_name, file_name):
# bucket_name = 'gunjanjalori-testing-xlml'
# file_name = 'nemo-experiments/gpt3-xlml-1731373474-175b-nemo-1731373494-ic5n/metrics.txt'
Expand All @@ -193,15 +200,15 @@ def get_metrics_from_gcs(bucket_name, file_name):


def extract_bucket_file_name(bash_result_output):
complete_job_name = None
metrics_file = None
for line in bash_result_output.splitlines():
if line.startswith("COMPLETE_JOB_NAME="):
complete_job_name = line.split("=", 1)[1]
if line.startswith("METRICS_FILE="):
metrics_file = line.split("=", 1)[1]
break
if complete_job_name:
if metrics_file:
# Extract bucket_name and file_name
bucket_name = re.search(r"gs://([^/]+)/", complete_job_name).group(1)
file_name = re.search(r"gs://[^/]+/(.+)", complete_job_name).group(1)
bucket_name = re.search(r"gs://([^/]+)/", metrics_file).group(1)
file_name = re.search(r"gs://[^/]+/(.+)", metrics_file).group(1)

print(f"Bucket name: {bucket_name}")
print(f"File name: {file_name}")
Expand All @@ -219,5 +226,3 @@ def extract_python_path(bash_result_output):
print(f"Pyhon path name: {python_path}")

return python_path


0 comments on commit 443d369

Please sign in to comment.