Add BQ writing feature #507

Merged: 35 commits, Dec 17, 2024

Changes from 33 commits

Commits (35)
9a21b7d
clone aotc and get metrics from gcs
gunjanj007 Dec 6, 2024
970b9ca
reformat
gunjanj007 Dec 6, 2024
5bdcef5
reformat
gunjanj007 Dec 6, 2024
257c148
reformat
gunjanj007 Dec 6, 2024
fa6219a
reformat
gunjanj007 Dec 6, 2024
443d369
reformat
gunjanj007 Dec 6, 2024
b10cee0
minor fix
gunjanj007 Dec 6, 2024
e1d5573
reformat
gunjanj007 Dec 6, 2024
fae639b
merge the directory of hook and python task
gunjanj007 Dec 6, 2024
b305ecc
reformat
gunjanj007 Dec 6, 2024
697e759
reformat
gunjanj007 Dec 6, 2024
1d811ae
reformat
gunjanj007 Dec 6, 2024
0a3c5f7
reformat
gunjanj007 Dec 7, 2024
4ef3684
reformat
gunjanj007 Dec 8, 2024
bfdb497
reformat
gunjanj007 Dec 9, 2024
19c144f
resolve comments
gunjanj007 Dec 9, 2024
db99876
resolve comments
gunjanj007 Dec 10, 2024
d1dbb73
resolve comments
gunjanj007 Dec 10, 2024
24b0cc2
reformat
gunjanj007 Dec 10, 2024
757e6ee
Add Dan and Di as owners for aotc
gunjanj007 Dec 11, 2024
b19ddf9
sync merge branch 'master' into gunjanj
gunjanj007 Dec 12, 2024
9402023
Add bq writing logic
gunjanj007 Dec 12, 2024
31ee221
Merge branch 'master' of https://github.com/GoogleCloudPlatform/ml-au…
gunjanj007 Dec 16, 2024
0dca000
Add bq writer
gunjanj007 Dec 16, 2024
1b3c39f
format fix
gunjanj007 Dec 16, 2024
30eaed5
fix fields name
gunjanj007 Dec 16, 2024
b69b9d9
clean code
gunjanj007 Dec 17, 2024
93470fd
clean code
gunjanj007 Dec 17, 2024
77b981d
clean code
gunjanj007 Dec 17, 2024
464637d
clean code
gunjanj007 Dec 17, 2024
fe91242
fix format
gunjanj007 Dec 17, 2024
6b5724b
fix format
gunjanj007 Dec 17, 2024
39c86ad
resolve comments
gunjanj007 Dec 17, 2024
dbebc8b
resolve comments
gunjanj007 Dec 17, 2024
8b3d992
Merge branch 'master' into gunjanj
gunjanj007 Dec 17, 2024
263 changes: 263 additions & 0 deletions dags/map_reproducibility/benchmarkdb_utils.py
@@ -0,0 +1,263 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"Bash helper commands for AOTC artifacts"
import sys
import os
import getpass


def write_run(
    model_id: str,
    hardware_id: str,
    software_id: str,
    number_of_nodes: int,
    number_of_chips: int,
    container_image_name: str,
    global_batch_size: int,
    precision: str,
    optimizer: str,
    seq_length: int,
    median_step_time: float,
    e2e_time: float,
    number_of_steps: int,
    mfu: float,
    tokens_per_second: float,
    writer_path: str,
    run_success: bool = True,  # True because if mfu is none, writing to db will fail anyway.
    run_type: str = "perf_regression",
    run_release_status: str = "local",
    other_metrics_in_json: str = "",
    nccl_driver_nickname: str = None,
    env_variables: str = "",
    framework_config_in_json: str = "",
    xla_flags: str = "",
    topology: str = "",
    dataset: str = "",
    num_of_superblock: int = None,
    update_person_ldap: str = getpass.getuser(),
    comment: str = "",
    is_test: bool = False,
):
"""Writes a workload benchmark run manually to the database.

This function validates the provided IDs and, if valid, constructs a
WorkloadBenchmarkV2Schema object with the given data and writes it to the
"run_summary" table in BigQuery.

Args:
model_id: The ID of the model used in the run.
hardware_id: The ID of the hardware used in the run.
software_id: The ID of the software stack used in the run.
number_of_nodes: The number of nodes used in the run.
number_of_chips: The number of chips used in the run.
container_image_name: The name of the container image used in the run.
global_batch_size: The global batch size used in the run.
precision: The precision used in the run (e.g., fp32, bf16).
optimizer: The optimizer used in the run (e.g., adam, sgd).
seq_length: The sequence length used in the run.
median_step_time: The median step time of the run.
e2e_time: The end-to-end time of the run.
number_of_steps: The number of steps taken in the run.
mfu: The MFU (model flops utilization) achieved in the run.
tokens_per_second: The tokens per second achieved in the run.
run_type: The type of run (default: "perf_optimization").
run_release_status: possible values "local" ( code changes are done locally), "prep_release" ( all code code changes are present in the image)
other_metrics_in_json: A JSON string containing other metrics.
nccl_driver_nickname: The nickname of the NCCL driver used.
env_variables: A string containing environment variables.
framework_config_in_json: A JSON string containing framework configurations.
xla_flags: A json string containing all the XLA flags.
topology: The topology of the hardware used in the run. ( valid for TPUs)
dataset: The dataset used in the run.
num_of_superblock: The number of superblocks in the hardware. ( valid for GPUs)
update_person_ldap: The LDAP ID of the person updating the record (default: current user).
comment: A comment about the run.
is_test: Whether to use the testing project or the production project.

Raises:
ValueError: If any of the IDs are invalid.
"""

  sys.path.append(writer_path)

  # pylint: disable=import-outside-toplevel
  import logging
  import uuid
  from typing import Type

  from aotc.benchmark_db_writer import bq_writer_utils
  from aotc.benchmark_db_writer.schema.workload_benchmark_v2 import workload_benchmark_v2_schema
  from aotc.benchmark_db_writer.schema.workload_benchmark_v2 import model_info_schema
  from aotc.benchmark_db_writer.schema.workload_benchmark_v2 import software_info_schema
  from aotc.benchmark_db_writer.schema.workload_benchmark_v2 import hardware_info_schema
  # pylint: enable=import-outside-toplevel
  logging.basicConfig(
      format="%(asctime)s %(levelname)-8s %(message)s",
      level=logging.INFO,
      datefmt="%Y-%m-%d %H:%M:%S",
  )
  logger = logging.getLogger(__name__)

  def get_db_client(
      table: str, dataclass_type: Type, is_test: bool = False
  ) -> bq_writer_utils.create_bq_writer_object:
    """Creates a BigQuery client object.

    Args:
      table: The name of the BigQuery table.
      dataclass_type: The dataclass type corresponding to the table schema.
      is_test: Whether to use the testing project or the production project.

    Returns:
      A BigQuery client object.
    """

    project = "supercomputer-testing" if is_test else "ml-workload-benchmarks"
    dataset = "mantaray_v2" if is_test else "benchmark_dataset_v2"
    return bq_writer_utils.create_bq_writer_object(
        project=project,
        dataset=dataset,
        table=table,
        dataclass_type=dataclass_type,
    )

  def _validate_id(
      id_value: str,
      table_name: str,
      id_field: str,
      dataclass_type: Type,
      is_test: bool = False,
  ) -> bool:
    """Generic function to validate an ID against a BigQuery table.

    Args:
      id_value: The ID value to validate.
      table_name: The name of the BigQuery table.
      id_field: The name of the ID field in the table.
      dataclass_type: The dataclass type corresponding to the table schema.
      is_test: Whether to use the testing project or the production project.

    Returns:
      True if the ID is valid, False otherwise.
    """

    client = get_db_client(table_name, dataclass_type, is_test)
    result = client.query(where={id_field: id_value})

    if not result:
      logger.info(
          "%s: %s is not present in the %s table",
          id_field.capitalize(),
          id_value,
          table_name,
      )
      logger.info(
          "Please add a row for %s to the %s table before adding it to the run_summary table",
          id_value,
          table_name,
      )
      return False
    return True

  def validate_model_id(model_id: str, is_test: bool = False) -> bool:
    """Validates a model ID against the model_info table."""

    print("model id: " + model_id)
    id_val = _validate_id(
        model_id, "model_info", "model_id", model_info_schema.ModelInfo, is_test
    )
    if not id_val:
      print("model id validation failed")
      return False
    return True

  def validate_hardware_id(hardware_id: str, is_test: bool = False) -> bool:
    """Validates a hardware ID against the hardware_info table."""
    id_val = _validate_id(
        hardware_id,
        "hardware_info",
        "hardware_id",
        hardware_info_schema.HardwareInfo,
        is_test,
    )
    if not id_val:
      print("hardware id validation failed")
      return False
    return True

  def validate_software_id(software_id: str, is_test: bool = False) -> bool:
    """Validates a software ID against the software_info table."""
    id_val = _validate_id(
        software_id,
        "software_info",
        "software_id",
        software_info_schema.SoftwareInfo,
        is_test,
    )

    if not id_val:
      print("software id validation failed")
      return False
    return True

  print(model_id)

  if (
      validate_model_id(model_id, is_test)
      and validate_hardware_id(hardware_id, is_test)
      and validate_software_id(software_id, is_test)
  ):
    summary = workload_benchmark_v2_schema.WorkloadBenchmarkV2Schema(
        run_id=f"run-{uuid.uuid4()}",
        model_id=model_id,
        software_id=software_id,
        hardware_id=hardware_id,
        hardware_num_chips=number_of_chips,
        hardware_num_nodes=number_of_nodes,
        result_success=run_success,
        configs_framework=framework_config_in_json,
        configs_env=env_variables,
        configs_container_version=container_image_name,
        configs_xla_flags=xla_flags,
        configs_dataset=dataset,
        logs_artifact_directory="",
        update_person_ldap=update_person_ldap,
        run_source="automation",
        run_type=run_type,
        run_release_status=run_release_status,
        workload_precision=precision,
        workload_gbs=global_batch_size,
        workload_optimizer=optimizer,
        workload_sequence_length=seq_length,
        metrics_e2e_time=e2e_time,
        metrics_mfu=mfu,
        metrics_step_time=median_step_time,
        metrics_tokens_per_second=tokens_per_second,
        metrics_steps_for_convergence=number_of_steps,
        metrics_other=other_metrics_in_json,
        hardware_nccl_driver_nickname=nccl_driver_nickname,
        hardware_topology=topology,
        hardware_num_superblocks=num_of_superblock,
        logs_comments=comment,
    )

    client = get_db_client(
        "run_summary",
        workload_benchmark_v2_schema.WorkloadBenchmarkV2Schema,
        is_test,
    )
    client.write([summary])

  else:
    raise ValueError("Could not upload data in run summary table")
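
A minimal usage sketch of the new writer, for reviewers who want to try it outside Airflow. It mirrors the call added in nemo_gpt3.py below; the metric values and the writer_path are illustrative assumptions, not output from a real run, and the model/hardware/software IDs must already exist in the model_info, hardware_info, and software_info tables of the test dataset.

# Illustrative only: exercises write_run() against the test project
# ("supercomputer-testing" / "mantaray_v2"); all metric values are made up.
from dags.map_reproducibility.benchmarkdb_utils import write_run

write_run(
    model_id="gpt3-175b",
    hardware_id="a3mega",
    software_id="pytorch_nemo",
    number_of_nodes=32,
    number_of_chips=32 * 8,
    container_image_name="nemo_workload:24.07",
    global_batch_size=2048,
    precision="fp8",
    optimizer="adam",
    seq_length=2048,
    median_step_time=11.5,  # seconds; illustrative
    e2e_time=0,
    number_of_steps=15,
    mfu=0.42,  # illustrative
    tokens_per_second=1,
    writer_path="/path/to/aotc",  # local checkout containing benchmark_db_writer (assumed)
    topology="2X2",
    comment="Manual smoke test of the BQ writer",
    is_test=True,  # write to the test dataset, not production
)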
76 changes: 66 additions & 10 deletions dags/map_reproducibility/nemo_gpt3.py
@@ -16,6 +16,7 @@

import datetime
import sys
import os
import tempfile

from airflow import models
@@ -32,16 +33,23 @@
from dags.map_reproducibility.utils import cleanup_cmds
from dags.map_reproducibility.utils import git_cookie_authdaemon
from dags.map_reproducibility.utils import clone_gob
from dags.map_reproducibility.utils import helm_install_cmds
from dags.map_reproducibility.utils import get_metrics_from_gcs
from dags.map_reproducibility.utils import helm_apply_cmds
from dags.map_reproducibility.utils import get_metrics
from dags.map_reproducibility.utils import get_aotc_repo
from dags.map_reproducibility.utils import extract_bucket_file_name
from dags.map_reproducibility.utils import extract_python_path
from dags.map_reproducibility.benchmarkdb_utils import write_run
from dags.map_reproducibility.utils import extract_run_details


# Run once a day at 2 pm UTC (6 am PST)
SCHEDULED_TIME = "0 14 * * *" if composer_env.is_prod_env() else None

MODEL_ID = "gpt3-175b"
BATCH_SIZE = 2048
NUM_ACCELERATORS = 256
PRECISION = "fp8"
ACCELERATOR_TYPE = "h100"


@task
def run_aotc_workload():
@@ -60,6 +68,7 @@ def run_aotc_workload():

with tempfile.TemporaryDirectory() as tmpdir:
hook = SubprocessHook()
# TODO(gunjanjalori): clone recipe first and extract params
result = hook.run_command(
[
"bash",
@@ -73,10 +82,16 @@
+ install_helm_cmds()
+ namespace_cmds()
+ workload_cmds
+ helm_install_cmds()
+ helm_apply_cmds()
+ wait_for_jobs_cmds()
+ copy_bucket_cmds()
+ get_metrics_cmds()
+ get_metrics_cmds(
BATCH_SIZE,
NUM_ACCELERATORS,
PRECISION,
MODEL_ID,
ACCELERATOR_TYPE,
)
+ cleanup_cmds()
+ get_aotc_repo()
),
@@ -85,13 +100,54 @@
)
assert result.exit_code == 0, f"Command failed with code {result.exit_code}"

# Extract COMPLETE_JOB_NAME from the output
bucket_name, file_name, python_path = extract_bucket_file_name(
result.output
python_base_path, python_path_to_bq_writer = extract_python_path(
result.output.splitlines()[-1]
)
get_metrics_from_gcs(bucket_name, file_name)
print(f"Base path in python: {python_base_path}")
print(f"python to bq: {python_path_to_bq_writer}")

value_yaml_path = "reproducible-benchmark-recipes/projects/gpu-recipes/training/a3mega/gpt3-175b/nemo-pretraining-gke/values.yaml"
config_yaml_path = "reproducible-benchmark-recipes/projects/gpu-recipes/src/frameworks/a3mega/nemo-configs/gpt3-175b-256gpus-fp8.yaml"

sys.path.append(python_path)
(
number_of_nodes,
global_batch_size,
optimizer,
precision,
seq_length,
max_steps,
) = extract_run_details(tmpdir, value_yaml_path, config_yaml_path)
print(
f"batch size: {global_batch_size}, number of nodes: {number_of_nodes}"
)
average_step_time, mfu = get_metrics(python_base_path)
model_id = "gpt3-175b"
hardware_id = "a3mega"
software_id = "pytorch_nemo"
image_version = "nemo_workload:24.07"
number_of_chips = number_of_nodes * 8

write_run(
model_id=model_id,
hardware_id=hardware_id,
software_id=software_id,
number_of_nodes=number_of_nodes,
number_of_chips=number_of_chips,
container_image_name=image_version,
global_batch_size=global_batch_size,
precision=precision,
optimizer=optimizer,
seq_length=seq_length,
median_step_time=average_step_time,
e2e_time=0,
number_of_steps=max_steps,
mfu=mfu,
tokens_per_second=1,
writer_path=python_path_to_bq_writer,
topology="2X2",
comment="Regression tests",
is_test=True,
)


with models.DAG(
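
And a sketch of how a reviewer might confirm that the row landed in BigQuery after a DAG run with is_test=True, assuming read access to the test project and that the dataclass field names match the column names of the run_summary table:

# Illustrative check against the test dataset used when is_test=True.
from google.cloud import bigquery

client = bigquery.Client(project="supercomputer-testing")
query = (
    "SELECT run_id, model_id, metrics_mfu, metrics_step_time "
    "FROM `supercomputer-testing.mantaray_v2.run_summary` "
    "WHERE model_id = 'gpt3-175b' LIMIT 5"
)
for row in client.query(query).result():
  print(row.run_id, row.model_id, row.metrics_mfu, row.metrics_step_time)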