Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
gunjanj007 committed Dec 17, 2024
1 parent 464637d commit fe91242
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 12 deletions.
4 changes: 0 additions & 4 deletions dags/map_reproducibility/benchmarkdb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ def validate_model_id(model_id: str, is_test: bool = False) -> bool:
return False
return True


def validate_hardware_id(hardware_id: str, is_test: bool = False) -> bool:
"""Validates a hardware ID against the hardware_info table."""
id_val = _validate_id(
Expand All @@ -197,7 +196,6 @@ def validate_hardware_id(hardware_id: str, is_test: bool = False) -> bool:
return False
return True


def validate_software_id(software_id: str, is_test: bool = False) -> bool:
"""Validates a software ID against the software_info table."""
id_val = _validate_id(
Expand All @@ -213,15 +211,13 @@ def validate_software_id(software_id: str, is_test: bool = False) -> bool:
return False
return True


print(model_id)

if (
validate_model_id(model_id, is_test)
and validate_hardware_id(hardware_id, is_test)
and validate_software_id(software_id, is_test)
):

summary = workload_benchmark_v2_schema.WorkloadBenchmarkV2Schema(
run_id=f"run-{uuid.uuid4()}",
model_id=model_id,
Expand Down
4 changes: 1 addition & 3 deletions dags/map_reproducibility/nemo_gpt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,7 @@ def run_aotc_workload():
print(
f"batch size: {global_batch_size}, number of nodes: {number_of_nodes}"
)
average_step_time, mfu = get_metrics(
python_base_path
)
average_step_time, mfu = get_metrics(python_base_path)
model_id = "gpt3-175b"
hardware_id = "a3mega"
software_id = "pytorch_nemo"
Expand Down
9 changes: 4 additions & 5 deletions dags/map_reproducibility/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def namespace_cmds():
return namespace


def helm_install_cmds():
def helm_apply_cmds():
helm_cmds = (
" helm install -f values.yaml "
"--namespace default "
Expand Down Expand Up @@ -130,7 +130,7 @@ def copy_bucket_cmds():

def get_metrics_cmds():
# TODO(gunjanj007): get these parameters from the recipe
get_metrics = (
cmds = (
# "METRICS_FILE=$COMPLETE_JOB_NAME/metrics.txt",
"METRICS_FILE=metrics.txt",
"python3 process_training_results.py --file"
Expand All @@ -142,7 +142,7 @@ def get_metrics_cmds():
"gsutil cp - $METRICS_FILE",
'echo "METRICS_FILE=${METRICS_FILE}"',
)
return get_metrics
return cmds


def get_aotc_repo():
Expand Down Expand Up @@ -170,7 +170,6 @@ def cleanup_cmds():


def get_metrics(metrics_path):

file_content = ""
with open(metrics_path + "/metrics.txt", "r", encoding="utf-8") as file:
file_content = file.read()
Expand All @@ -195,7 +194,7 @@ def extract_python_path(last_line):
return python_path, python_path_to_bq_writer


def extract_gpus(tmpdir, yaml_file, config_path):
def extract_run_details(tmpdir, yaml_file, config_path):
gpus = None
batch_size = None
optimizer = None
Expand Down

0 comments on commit fe91242

Please sign in to comment.