From 1f99e87fde4e0a5dfff14c70caf5a8f9dd73c46c Mon Sep 17 00:00:00 2001 From: Priyanka Ganesha Date: Fri, 11 Aug 2023 09:19:21 -0700 Subject: [PATCH 01/10] metrics --- MaxText/emit_metrics.py | 111 ++++++++++++++++++++++++++++++++++++++++ requirements.txt | 2 + 2 files changed, 113 insertions(+) create mode 100644 MaxText/emit_metrics.py diff --git a/MaxText/emit_metrics.py b/MaxText/emit_metrics.py new file mode 100644 index 000000000..69029ba3c --- /dev/null +++ b/MaxText/emit_metrics.py @@ -0,0 +1,111 @@ +from google.cloud import monitoring_v3 +from google.cloud import compute_v1 +from google.api import metric_pb2 +import time +import os + +PROJECT_ID="cloud-tpu-multipod-dev" +ZONE = "us-central2-b" + +def create_custom_metric(metric_name, description): + metric_name = "checkpoint_init" + project_name = f"projects/{PROJECT_ID}" + + client = monitoring_v3.MetricServiceClient() + + descriptor = metric_pb2.MetricDescriptor() + descriptor.type = "custom.googleapis.com/" + metric_name + descriptor.metric_kind = metric_pb2.MetricDescriptor.MetricKind.GAUGE + descriptor.value_type = metric_pb2.MetricDescriptor.ValueType.DOUBLE + descriptor.description = description + + request = monitoring_v3.CreateMetricDescriptorRequest( + name=project_name, + metric_descriptor=descriptor + ) + + response = client.create_metric_descriptor(request=request) + + print(response) + + +def write_time_series_step(metric_name, step, status): + """ + Emits a data point when a training step is STARTED and COMPLETED. + Args: + metric_name: name of the metric + step: training step + status: STARTED if the training step is started, else COMPLETED + """ + client = get_metrics_service_client() + project_name = f"projects/{PROJECT_ID}" + + seconds_since_epoch_utc = time.time() + nanos_since_epoch_utc = int( + (seconds_since_epoch_utc - int(seconds_since_epoch_utc)) * 10**9 + ) + interval = monitoring_v3.types.TimeInterval( + { + "end_time": { + "seconds": int(seconds_since_epoch_utc), + "nanos": nanos_since_epoch_utc, + } + } + ) + + event_time = time.strftime( + "%d %b %Y %H:%M:%S UTC", time.gmtime(seconds_since_epoch_utc) + ) + print( + "Emitting status = ", + status, + " for step = ", + step, + " at: ", + event_time, + ) + + instance_id = get_instance_id() + + series = monitoring_v3.types.TimeSeries() + series.metric.type = "custom.googleapis.com/" + metric_name + series.resource.type = "gce_instance" + series.resource.labels["instance_id"] = str(instance_id) + series.resource.labels["zone"] = ZONE + series.metric.labels["step_num"] = str(step) + series.metric.labels["worker"] = os.uname().nodename + series.metric.labels["status"] = status + series.metric.labels["event_time"] = event_time + series.points = [ + monitoring_v3.types.Point( + interval=interval, + value=monitoring_v3.types.TypedValue( + double_value=step + ), + ) + ] + + client.create_time_series(name=project_name, time_series=[series]) + print( + "Time series added for step", + step, + "and instance_id ", + instance_id, + " and zone ", + ZONE, + ) + +def get_instance_id(): + client = get_compute_instances_client() + instance_name = os.uname().nodename + instance = client.get(project=PROJECT_ID, zone=ZONE, instance=instance_name) + return instance.id + +def get_compute_instances_client(): + return compute_v1.InstancesClient() + +def get_metrics_service_client(): + return monitoring_v3.MetricServiceClient() + +if __name__ == "__main__": + create_custom_metric('checkpointing_init', get_metrics_service_client(), "This is a checkpointing init metric.") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 68db8280b..40851fa81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,8 @@ absl-py argparse cloud-tpu-diagnostics datetime +google-cloud-compute==1.6.1 +google-cloud-monitoring==2.11.3 google-cloud-storage flax ml-collections From ec2a1772a8009c055170de1392f53444987ec29c Mon Sep 17 00:00:00 2001 From: Priyanka Ganesha Date: Fri, 22 Sep 2023 07:30:18 -0700 Subject: [PATCH 02/10] cloud monitoring prototype and checkpoint initialization metrics emitting --- MaxText/configs/base.yml | 3 ++ MaxText/emit_metrics.py | 67 ++++++++++++++++++++++++---------------- MaxText/train.py | 19 ++++++++++++ 3 files changed, 63 insertions(+), 26 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 5632c0318..f443b0baa 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -175,3 +175,6 @@ stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds # Use iota operator in Embed use_iota_embed: False + +#Monitoring parameters +enable_cloud_monitoring: True diff --git a/MaxText/emit_metrics.py b/MaxText/emit_metrics.py index 69029ba3c..83c00b578 100644 --- a/MaxText/emit_metrics.py +++ b/MaxText/emit_metrics.py @@ -1,15 +1,20 @@ +import subprocess +import sys from google.cloud import monitoring_v3 from google.cloud import compute_v1 from google.api import metric_pb2 +import requests import time import os -PROJECT_ID="cloud-tpu-multipod-dev" -ZONE = "us-central2-b" +def get_metadata(project_id, zone, instance_id): + r = requests.get(url="https://compute.googleapis.com/compute/v1/projects/{project_id}/zones/{zone}/instances/{instance_id}") + metadata = r.json() + return metadata def create_custom_metric(metric_name, description): - metric_name = "checkpoint_init" - project_name = f"projects/{PROJECT_ID}" + project_id = get_project() + project_name = f"projects/{project_id}" client = monitoring_v3.MetricServiceClient() @@ -29,16 +34,16 @@ def create_custom_metric(metric_name, description): print(response) -def write_time_series_step(metric_name, step, status): - """ - Emits a data point when a training step is STARTED and COMPLETED. - Args: - metric_name: name of the metric - step: training step - status: STARTED if the training step is started, else COMPLETED - """ +def write_time_series_step(metric_name, monitoring_enabled, step=1): + + zone = get_zone() + project_id = get_project() + + if not monitoring_enabled: + return + client = get_metrics_service_client() - project_name = f"projects/{PROJECT_ID}" + project_name = f"projects/{project_id}" seconds_since_epoch_utc = time.time() nanos_since_epoch_utc = int( @@ -57,24 +62,23 @@ def write_time_series_step(metric_name, step, status): "%d %b %Y %H:%M:%S UTC", time.gmtime(seconds_since_epoch_utc) ) print( - "Emitting status = ", - status, + "Emitting metric ", + metric_name, " for step = ", step, " at: ", event_time, ) - instance_id = get_instance_id() + instance_id = get_instance_id(project_id, zone) series = monitoring_v3.types.TimeSeries() series.metric.type = "custom.googleapis.com/" + metric_name series.resource.type = "gce_instance" series.resource.labels["instance_id"] = str(instance_id) - series.resource.labels["zone"] = ZONE + series.resource.labels["zone"] = zone series.metric.labels["step_num"] = str(step) series.metric.labels["worker"] = os.uname().nodename - series.metric.labels["status"] = status series.metric.labels["event_time"] = event_time series.points = [ monitoring_v3.types.Point( @@ -85,27 +89,38 @@ def write_time_series_step(metric_name, step, status): ) ] - client.create_time_series(name=project_name, time_series=[series]) + client.create_time_series(name=project_name, time_series=[series], metadata=get_metadata(project_id, zone, instance_id)) print( "Time series added for step", step, "and instance_id ", instance_id, " and zone ", - ZONE, + zone, ) -def get_instance_id(): +def get_instance_id(project_id, zone): client = get_compute_instances_client() instance_name = os.uname().nodename - instance = client.get(project=PROJECT_ID, zone=ZONE, instance=instance_name) + instance = client.get(project=project_id, zone=zone, instance=instance_name) return instance.id +def get_project(): + completed_command = subprocess.run(["gcloud", "config", "get", "project"], check=True, capture_output=True) + project_outputs = completed_command.stdout.decode().strip().split('\n') + if len(project_outputs) < 1 or project_outputs[-1]=='': + sys.exit("You must specify the project in the PROJECT flag or set it with 'gcloud config set project '") + return project_outputs[-1] + +def get_zone(): + completed_command = subprocess.run(["gcloud", "config", "get", "compute/zone"], check=True, capture_output=True) + zone_outputs = completed_command.stdout.decode().strip().split('\n') + if len(zone_outputs) < 1 or zone_outputs[-1]=='': + sys.exit("You must specify the zone in the ZONE flag or set it with 'gcloud config set compute/zone '") + return zone_outputs[-1] + def get_compute_instances_client(): return compute_v1.InstancesClient() def get_metrics_service_client(): - return monitoring_v3.MetricServiceClient() - -if __name__ == "__main__": - create_custom_metric('checkpointing_init', get_metrics_service_client(), "This is a checkpointing init metric.") \ No newline at end of file + return monitoring_v3.MetricServiceClient() \ No newline at end of file diff --git a/MaxText/train.py b/MaxText/train.py index 53ae2420a..002e42ca1 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -56,6 +56,7 @@ from cloud_tpu_diagnostics.configuration import stack_trace_configuration import max_logging +import emit_metrics cc.initialize_cache(os.path.expanduser("~/jax_cache")) # https://arxiv.org/pdf/2204.02311.pdf Appendix B @@ -211,13 +212,30 @@ def train_loop(config, state=None): Returns: """ + + monitoring_enabled = config.enable_cloud_monitoring + + if monitoring_enabled: + emit_metrics.create_custom_metric('checkpointing_init_start', "Checkpointing Initialization Start") + emit_metrics.create_custom_metric('checkpointing_init_end', "Checkpointing Initialization End") + emit_metrics.create_custom_metric('checkpoint_test_run_start', "Checkpointing Test Run Start") + emit_metrics.create_custom_metric('checkpoint_test_run_end', "Checkpointing Test Run End") + + emit_metrics.write_time_series_step('checkpoint_test_run_start', 0, monitoring_enabled) + writer = SummaryWriter(config.tensorboard_dir) + + emit_metrics.write_time_series_step('checkpointing_init_start', 1, monitoring_enabled) + checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( config.checkpoint_dir, config.enable_checkpointing, config.async_checkpointing, config.save_period, ) + + emit_metrics.write_time_series_step('checkpointing_init_end', 1, monitoring_enabled) + # Initial PRNG Keys init_rng, nextrng = random.split(random.PRNGKey(config.init_weights_seed), 2) @@ -300,6 +318,7 @@ def train_loop(config, state=None): if step == 0: max_utils.activate_profiler(config) + emit_metrics.write_time_series_step('checkpoint_test_run_end', config.steps, monitoring_enabled) max_utils.deactivate_profiler(config) writer.close() return state From 63265142699d25f82f99350fae245c1498cb266f Mon Sep 17 00:00:00 2001 From: Priyanka Ganesha Date: Fri, 22 Sep 2023 07:58:06 -0700 Subject: [PATCH 03/10] pylint errors --- MaxText/emit_metrics.py | 87 +++++++++++++++++++++++++++++++++-------- MaxText/train.py | 2 +- 2 files changed, 71 insertions(+), 18 deletions(-) diff --git a/MaxText/emit_metrics.py b/MaxText/emit_metrics.py index 83c00b578..c7d0f5f07 100644 --- a/MaxText/emit_metrics.py +++ b/MaxText/emit_metrics.py @@ -1,3 +1,7 @@ +""" +Cloud Monitoring API v3 Prototype +""" + import subprocess import sys from google.cloud import monitoring_v3 @@ -8,40 +12,70 @@ import os def get_metadata(project_id, zone, instance_id): - r = requests.get(url="https://compute.googleapis.com/compute/v1/projects/{project_id}/zones/{zone}/instances/{instance_id}") + """ + Fetches metadata + + Args: + project_id + zone + instance_id + + Returns: + metadata as json + """ + r = requests.get(url="https://compute.googleapis.com/compute/v1/projects/\ + {project_id}/zones/{zone}/instances/{instance_id}") metadata = r.json() return metadata def create_custom_metric(metric_name, description): - project_id = get_project() - project_name = f"projects/{project_id}" + """ + Creates a custom metric - client = monitoring_v3.MetricServiceClient() + Args: + metric_name + description + + Returns: + Response from create request + """ + project_id = get_project() + project_name = f"projects/{project_id}" - descriptor = metric_pb2.MetricDescriptor() - descriptor.type = "custom.googleapis.com/" + metric_name - descriptor.metric_kind = metric_pb2.MetricDescriptor.MetricKind.GAUGE - descriptor.value_type = metric_pb2.MetricDescriptor.ValueType.DOUBLE - descriptor.description = description + client = monitoring_v3.MetricServiceClient() - request = monitoring_v3.CreateMetricDescriptorRequest( - name=project_name, - metric_descriptor=descriptor - ) + descriptor = metric_pb2.MetricDescriptor() + descriptor.type = "custom.googleapis.com/" + metric_name + descriptor.metric_kind = metric_pb2.MetricDescriptor.MetricKind.GAUGE + descriptor.value_type = metric_pb2.MetricDescriptor.ValueType.DOUBLE + descriptor.description = description - response = client.create_metric_descriptor(request=request) + request = monitoring_v3.CreateMetricDescriptorRequest( + name=project_name, + metric_descriptor=descriptor + ) - print(response) + response = client.create_metric_descriptor(request=request) + + return response def write_time_series_step(metric_name, monitoring_enabled, step=1): + """ + Writes a time series object for a specified custom metric + + Args: + metric_name + monitoring_enabled + step + """ zone = get_zone() project_id = get_project() if not monitoring_enabled: return - + client = get_metrics_service_client() project_name = f"projects/{project_id}" @@ -100,12 +134,22 @@ def write_time_series_step(metric_name, monitoring_enabled, step=1): ) def get_instance_id(project_id, zone): + """ + Fetches instance id of a node + + Args: + project_id + zone + """ client = get_compute_instances_client() instance_name = os.uname().nodename instance = client.get(project=project_id, zone=zone, instance=instance_name) return instance.id def get_project(): + """ + Fetches id of project in use + """ completed_command = subprocess.run(["gcloud", "config", "get", "project"], check=True, capture_output=True) project_outputs = completed_command.stdout.decode().strip().split('\n') if len(project_outputs) < 1 or project_outputs[-1]=='': @@ -113,6 +157,9 @@ def get_project(): return project_outputs[-1] def get_zone(): + """ + Fetches zone in use + """ completed_command = subprocess.run(["gcloud", "config", "get", "compute/zone"], check=True, capture_output=True) zone_outputs = completed_command.stdout.decode().strip().split('\n') if len(zone_outputs) < 1 or zone_outputs[-1]=='': @@ -120,7 +167,13 @@ def get_zone(): return zone_outputs[-1] def get_compute_instances_client(): + """ + Fetches cloud compute instances client + """ return compute_v1.InstancesClient() def get_metrics_service_client(): - return monitoring_v3.MetricServiceClient() \ No newline at end of file + """ + Fetches cloud monitoring API client + """ + return monitoring_v3.MetricServiceClient() diff --git a/MaxText/train.py b/MaxText/train.py index 002e42ca1..36f1acbc7 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -235,7 +235,7 @@ def train_loop(config, state=None): ) emit_metrics.write_time_series_step('checkpointing_init_end', 1, monitoring_enabled) - + # Initial PRNG Keys init_rng, nextrng = random.split(random.PRNGKey(config.init_weights_seed), 2) From 494ebf6910fffa0a2c83f0a2178e644106c115a2 Mon Sep 17 00:00:00 2001 From: Priyanka Ganesha Date: Fri, 22 Sep 2023 08:02:10 -0700 Subject: [PATCH 04/10] pylint --- MaxText/emit_metrics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/MaxText/emit_metrics.py b/MaxText/emit_metrics.py index c7d0f5f07..2b392a78f 100644 --- a/MaxText/emit_metrics.py +++ b/MaxText/emit_metrics.py @@ -1,3 +1,4 @@ +# pylint: disable=unused-argument, no-name-in-module """ Cloud Monitoring API v3 Prototype """ @@ -56,7 +57,7 @@ def create_custom_metric(metric_name, description): ) response = client.create_metric_descriptor(request=request) - + return response From 5ee4fe37768e983c8f7b6a4c76a6a5d020b79e73 Mon Sep 17 00:00:00 2001 From: Priyanka Ganesha Date: Fri, 22 Sep 2023 08:04:28 -0700 Subject: [PATCH 05/10] rename --- MaxText/{emit_metrics.py => monitoring_api.py} | 0 MaxText/train.py | 18 +++++++++--------- 2 files changed, 9 insertions(+), 9 deletions(-) rename MaxText/{emit_metrics.py => monitoring_api.py} (100%) diff --git a/MaxText/emit_metrics.py b/MaxText/monitoring_api.py similarity index 100% rename from MaxText/emit_metrics.py rename to MaxText/monitoring_api.py diff --git a/MaxText/train.py b/MaxText/train.py index 36f1acbc7..51aaf48e2 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -56,7 +56,7 @@ from cloud_tpu_diagnostics.configuration import stack_trace_configuration import max_logging -import emit_metrics +import monitoring_api cc.initialize_cache(os.path.expanduser("~/jax_cache")) # https://arxiv.org/pdf/2204.02311.pdf Appendix B @@ -216,16 +216,16 @@ def train_loop(config, state=None): monitoring_enabled = config.enable_cloud_monitoring if monitoring_enabled: - emit_metrics.create_custom_metric('checkpointing_init_start', "Checkpointing Initialization Start") - emit_metrics.create_custom_metric('checkpointing_init_end', "Checkpointing Initialization End") - emit_metrics.create_custom_metric('checkpoint_test_run_start', "Checkpointing Test Run Start") - emit_metrics.create_custom_metric('checkpoint_test_run_end', "Checkpointing Test Run End") + monitoring_api.create_custom_metric('checkpointing_init_start', "Checkpointing Initialization Start") + monitoring_api.create_custom_metric('checkpointing_init_end', "Checkpointing Initialization End") + monitoring_api.create_custom_metric('checkpoint_test_run_start', "Checkpointing Test Run Start") + monitoring_api.create_custom_metric('checkpoint_test_run_end', "Checkpointing Test Run End") - emit_metrics.write_time_series_step('checkpoint_test_run_start', 0, monitoring_enabled) + monitoring_api.write_time_series_step('checkpoint_test_run_start', 0, monitoring_enabled) writer = SummaryWriter(config.tensorboard_dir) - emit_metrics.write_time_series_step('checkpointing_init_start', 1, monitoring_enabled) + monitoring_api.write_time_series_step('checkpointing_init_start', 1, monitoring_enabled) checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( config.checkpoint_dir, @@ -234,7 +234,7 @@ def train_loop(config, state=None): config.save_period, ) - emit_metrics.write_time_series_step('checkpointing_init_end', 1, monitoring_enabled) + monitoring_api.write_time_series_step('checkpointing_init_end', 1, monitoring_enabled) # Initial PRNG Keys init_rng, nextrng = random.split(random.PRNGKey(config.init_weights_seed), 2) @@ -318,7 +318,7 @@ def train_loop(config, state=None): if step == 0: max_utils.activate_profiler(config) - emit_metrics.write_time_series_step('checkpoint_test_run_end', config.steps, monitoring_enabled) + monitoring_api.write_time_series_step('checkpoint_test_run_end', config.steps, monitoring_enabled) max_utils.deactivate_profiler(config) writer.close() return state From 54caff6749149b514a46cf657964fe68a158bf9f Mon Sep 17 00:00:00 2001 From: Priyanka Ganesha Date: Fri, 22 Sep 2023 08:07:28 -0700 Subject: [PATCH 06/10] rename --- MaxText/monitoring_api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/MaxText/monitoring_api.py b/MaxText/monitoring_api.py index 2b392a78f..3f3729b1f 100644 --- a/MaxText/monitoring_api.py +++ b/MaxText/monitoring_api.py @@ -161,6 +161,7 @@ def get_zone(): """ Fetches zone in use """ + subprocess.run("gcloud config set compute/zone us-central2-b") completed_command = subprocess.run(["gcloud", "config", "get", "compute/zone"], check=True, capture_output=True) zone_outputs = completed_command.stdout.decode().strip().split('\n') if len(zone_outputs) < 1 or zone_outputs[-1]=='': From c585140bb505eb2a3a0892ab079ccf9e7c66905c Mon Sep 17 00:00:00 2001 From: Priyanka Ganesha Date: Mon, 2 Oct 2023 13:41:58 -0700 Subject: [PATCH 07/10] changes based on comments --- MaxText/configs/base.yml | 4 +- MaxText/max_utils.py | 4 ++ MaxText/monitoring_api.py | 62 +++++++++++++++++++------- MaxText/pyconfig.py | 3 ++ MaxText/tests/cloud_monitoring_test.py | 39 ++++++++++++++++ MaxText/train.py | 16 +++---- requirements.txt | 4 +- 7 files changed, 106 insertions(+), 26 deletions(-) create mode 100644 MaxText/tests/cloud_monitoring_test.py diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index f443b0baa..2fc059923 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -176,5 +176,7 @@ stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds # Use iota operator in Embed use_iota_embed: False -#Monitoring parameters +#Monitoring parameters - Export in-workload metrics to Cloud monitoring enable_cloud_monitoring: True +cloud_monitoring_dashboard: "https://pantheon.corp.google.com/monitoring/dashboards?project=" +cloud_zone: "" # zone name for cloud jobs - used for cloud metrics emitting diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 02ba6ea30..210bdb392 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -20,6 +20,7 @@ import functools import max_logging +import monitoring_api import numpy as np import jax @@ -227,6 +228,9 @@ def setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager): state = unbox_logicallypartioned_trainstate(state) return state, state_mesh_annotations +def register_train_metrics(metric_name, metric_description): + monitoring_api.create_custom_metric(metric_name, metric_description) + # Learning Rate Schedule # ----------------------------------------------------------------------------- diff --git a/MaxText/monitoring_api.py b/MaxText/monitoring_api.py index 3f3729b1f..d2c2785df 100644 --- a/MaxText/monitoring_api.py +++ b/MaxText/monitoring_api.py @@ -12,6 +12,8 @@ import time import os +import max_logging + def get_metadata(project_id, zone, instance_id): """ Fetches metadata @@ -61,7 +63,7 @@ def create_custom_metric(metric_name, description): return response -def write_time_series_step(metric_name, monitoring_enabled, step=1): +def write_time_series_step(metric_name, monitoring_enabled, pyconfig, step=1): """ Writes a time series object for a specified custom metric @@ -71,7 +73,7 @@ def write_time_series_step(metric_name, monitoring_enabled, step=1): step """ - zone = get_zone() + zone = pyconfig.config.cloud_zone project_id = get_project() if not monitoring_enabled: @@ -96,7 +98,7 @@ def write_time_series_step(metric_name, monitoring_enabled, step=1): event_time = time.strftime( "%d %b %Y %H:%M:%S UTC", time.gmtime(seconds_since_epoch_utc) ) - print( + max_logging.log( "Emitting metric ", metric_name, " for step = ", @@ -125,14 +127,49 @@ def write_time_series_step(metric_name, monitoring_enabled, step=1): ] client.create_time_series(name=project_name, time_series=[series], metadata=get_metadata(project_id, zone, instance_id)) - print( + dashboard_link = pyconfig.config.cloud_monitoring_dashboard+project_name + max_logging.log( "Time series added for step", step, "and instance_id ", instance_id, " and zone ", zone, + "\nView dashboards or use metrics: ", + dashboard_link, ) + return [series] + +def get_time_series_step_data(metric_name): + """ + Retrieves time series data + + Args: + metric_name + """ + project_id = get_project() + project_name = f"projects/{project_id}" + instance_name = os.uname().nodename + + mql = """ + fetch gce_instance + | metric 'custom.googleapis.com/{metric_name}' + | filter (metric.worker == '{worker_id}') + | every 1m + | within -1d, 1d # one day, starting 1 day ago + """ + + client = get_query_service_client() + request = monitoring_v3.QueryTimeSeriesRequest({ + "name": project_name, + "query": mql.format( + metric_name=metric_name, worker_id=instance_name + ), + }) + + result = client.query_time_series(request) + return result.time_series_data + def get_instance_id(project_id, zone): """ @@ -157,17 +194,6 @@ def get_project(): sys.exit("You must specify the project in the PROJECT flag or set it with 'gcloud config set project '") return project_outputs[-1] -def get_zone(): - """ - Fetches zone in use - """ - subprocess.run("gcloud config set compute/zone us-central2-b") - completed_command = subprocess.run(["gcloud", "config", "get", "compute/zone"], check=True, capture_output=True) - zone_outputs = completed_command.stdout.decode().strip().split('\n') - if len(zone_outputs) < 1 or zone_outputs[-1]=='': - sys.exit("You must specify the zone in the ZONE flag or set it with 'gcloud config set compute/zone '") - return zone_outputs[-1] - def get_compute_instances_client(): """ Fetches cloud compute instances client @@ -179,3 +205,9 @@ def get_metrics_service_client(): Fetches cloud monitoring API client """ return monitoring_v3.MetricServiceClient() + +def get_query_service_client(): + """ + Fetches cloud monitoring query service client + """ + return monitoring_v3.QueryServiceClient() diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 8fcee03b2..7232168f9 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -19,6 +19,7 @@ import math import os +import subprocess import sys import yaml @@ -89,6 +90,8 @@ def user_init(raw_keys): raw_keys["run_name"] = os.environ.get("JOBSET_NAME") #using XPK default run_name = raw_keys["run_name"] assert run_name, "Erroring out, need a real run_name" + assert ((raw_keys['cloud_zone']!="" or not raw_keys['enable_cloud_monitoring']), + "You must provide cloud_zone if cloud monitoring is enabled") base_output_directory = raw_keys["base_output_directory"] validate_gcs_bucket_name(base_output_directory, "base_output_directory") dataset_path = raw_keys["dataset_path"] diff --git a/MaxText/tests/cloud_monitoring_test.py b/MaxText/tests/cloud_monitoring_test.py new file mode 100644 index 000000000..4e2318ab6 --- /dev/null +++ b/MaxText/tests/cloud_monitoring_test.py @@ -0,0 +1,39 @@ +""" + Copyright 2023 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +""" Tests for Cloud Monitoring API """ +import sys +import jax +import unittest + +import monitoring_api +import pyconfig + +jax.config.update('jax_platform_name', 'cpu') + +class CloudMonitoringTests(unittest.TestCase): + """Test for writing time series step using monitoring_api.py""" + def test_write_time_series_step(self): + pyconfig.initialize(sys.argv + ['configs/base.yml'], per_device_batch_size=1, run_name='test', cloud_zone='us-central2-b') + monitoring_api.create_custom_metric('test_metric', "This is an example metric") + create_time_series_result = monitoring_api.write_time_series_step('test_metric', True, pyconfig, 1) + query_time_series_result = monitoring_api.get_time_series_step_data('test_metric') + self.assertEqual(create_time_series_result, query_time_series_result) + + +if __name__ == '__main__': + unittest.main() + diff --git a/MaxText/train.py b/MaxText/train.py index 51aaf48e2..3a3703d54 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -216,16 +216,16 @@ def train_loop(config, state=None): monitoring_enabled = config.enable_cloud_monitoring if monitoring_enabled: - monitoring_api.create_custom_metric('checkpointing_init_start', "Checkpointing Initialization Start") - monitoring_api.create_custom_metric('checkpointing_init_end', "Checkpointing Initialization End") - monitoring_api.create_custom_metric('checkpoint_test_run_start', "Checkpointing Test Run Start") - monitoring_api.create_custom_metric('checkpoint_test_run_end', "Checkpointing Test Run End") + max_utils.register_train_metrics('checkpointint_init_start', "Checkpointing Initialization Start") + max_utils.register_train_metrics('checkpointing_init_end', "Checkpointing Initialization End") + max_utils.register_train_metrics('checkpoint_test_run_start', "Checkpointing Test Run Start") + max_utils.register_train_metrics('checkpoint_test_run_end', "Checkpointing Test Run End") - monitoring_api.write_time_series_step('checkpoint_test_run_start', 0, monitoring_enabled) + monitoring_api.write_time_series_step('checkpoint_test_run_start', monitoring_enabled, pyconfig, 0) writer = SummaryWriter(config.tensorboard_dir) - monitoring_api.write_time_series_step('checkpointing_init_start', 1, monitoring_enabled) + monitoring_api.write_time_series_step('checkpointing_init_start', monitoring_enabled, pyconfig, 1) checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( config.checkpoint_dir, @@ -234,7 +234,7 @@ def train_loop(config, state=None): config.save_period, ) - monitoring_api.write_time_series_step('checkpointing_init_end', 1, monitoring_enabled) + monitoring_api.write_time_series_step('checkpointing_init_end', monitoring_enabled, pyconfig, 1) # Initial PRNG Keys init_rng, nextrng = random.split(random.PRNGKey(config.init_weights_seed), 2) @@ -318,7 +318,7 @@ def train_loop(config, state=None): if step == 0: max_utils.activate_profiler(config) - monitoring_api.write_time_series_step('checkpoint_test_run_end', config.steps, monitoring_enabled) + monitoring_api.write_time_series_step('checkpoint_test_run_end', monitoring_enabled, pyconfig, config.steps) max_utils.deactivate_profiler(config) writer.close() return state diff --git a/requirements.txt b/requirements.txt index 40851fa81..007cc3388 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,8 +3,8 @@ absl-py argparse cloud-tpu-diagnostics datetime -google-cloud-compute==1.6.1 -google-cloud-monitoring==2.11.3 +google-cloud-compute +google-cloud-monitoring google-cloud-storage flax ml-collections From 3e374e85a9b02fc2f77129e9f649ad563223d9ef Mon Sep 17 00:00:00 2001 From: Priyanka Ganesha Date: Mon, 2 Oct 2023 13:52:06 -0700 Subject: [PATCH 08/10] pylint --- MaxText/monitoring_api.py | 4 ++-- MaxText/pyconfig.py | 5 ++--- MaxText/tests/cloud_monitoring_test.py | 5 ++++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/MaxText/monitoring_api.py b/MaxText/monitoring_api.py index d2c2785df..e244ac934 100644 --- a/MaxText/monitoring_api.py +++ b/MaxText/monitoring_api.py @@ -77,7 +77,7 @@ def write_time_series_step(metric_name, monitoring_enabled, pyconfig, step=1): project_id = get_project() if not monitoring_enabled: - return + return [] client = get_metrics_service_client() project_name = f"projects/{project_id}" @@ -166,7 +166,7 @@ def get_time_series_step_data(metric_name): metric_name=metric_name, worker_id=instance_name ), }) - + result = client.query_time_series(request) return result.time_series_data diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 7232168f9..9659d95e2 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -19,7 +19,6 @@ import math import os -import subprocess import sys import yaml @@ -90,8 +89,8 @@ def user_init(raw_keys): raw_keys["run_name"] = os.environ.get("JOBSET_NAME") #using XPK default run_name = raw_keys["run_name"] assert run_name, "Erroring out, need a real run_name" - assert ((raw_keys['cloud_zone']!="" or not raw_keys['enable_cloud_monitoring']), - "You must provide cloud_zone if cloud monitoring is enabled") + assert ((raw_keys['cloud_zone']!="" or not raw_keys['enable_cloud_monitoring'])),\ + "You must provide cloud_zone if cloud monitoring is enabled" base_output_directory = raw_keys["base_output_directory"] validate_gcs_bucket_name(base_output_directory, "base_output_directory") dataset_path = raw_keys["dataset_path"] diff --git a/MaxText/tests/cloud_monitoring_test.py b/MaxText/tests/cloud_monitoring_test.py index 4e2318ab6..42e532f72 100644 --- a/MaxText/tests/cloud_monitoring_test.py +++ b/MaxText/tests/cloud_monitoring_test.py @@ -27,7 +27,10 @@ class CloudMonitoringTests(unittest.TestCase): """Test for writing time series step using monitoring_api.py""" def test_write_time_series_step(self): - pyconfig.initialize(sys.argv + ['configs/base.yml'], per_device_batch_size=1, run_name='test', cloud_zone='us-central2-b') + pyconfig.initialize(sys.argv + ['configs/base.yml'], per_device_batch_size=1, run_name='test', mesh_axes = ['data'], + logical_axis_rules = [['batch', 'data']], + data_sharding = ['data'], + cloud_zone='us-central2-b') monitoring_api.create_custom_metric('test_metric', "This is an example metric") create_time_series_result = monitoring_api.write_time_series_step('test_metric', True, pyconfig, 1) query_time_series_result = monitoring_api.get_time_series_step_data('test_metric') From a5767b485bb12427b003f183f7c69cd6985ce6fd Mon Sep 17 00:00:00 2001 From: Priyanka Ganesha Date: Mon, 2 Oct 2023 14:00:16 -0700 Subject: [PATCH 09/10] change enable_monitoring --- MaxText/configs/base.yml | 2 +- MaxText/tests/cloud_monitoring_test.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 2fc059923..cf9ce4259 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -177,6 +177,6 @@ stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds use_iota_embed: False #Monitoring parameters - Export in-workload metrics to Cloud monitoring -enable_cloud_monitoring: True +enable_cloud_monitoring: False cloud_monitoring_dashboard: "https://pantheon.corp.google.com/monitoring/dashboards?project=" cloud_zone: "" # zone name for cloud jobs - used for cloud metrics emitting diff --git a/MaxText/tests/cloud_monitoring_test.py b/MaxText/tests/cloud_monitoring_test.py index 42e532f72..582389017 100644 --- a/MaxText/tests/cloud_monitoring_test.py +++ b/MaxText/tests/cloud_monitoring_test.py @@ -30,6 +30,7 @@ def test_write_time_series_step(self): pyconfig.initialize(sys.argv + ['configs/base.yml'], per_device_batch_size=1, run_name='test', mesh_axes = ['data'], logical_axis_rules = [['batch', 'data']], data_sharding = ['data'], + enable_cloud_monitoring=True, cloud_zone='us-central2-b') monitoring_api.create_custom_metric('test_metric', "This is an example metric") create_time_series_result = monitoring_api.write_time_series_step('test_metric', True, pyconfig, 1) From 48c5dc20239d231c6823ca4f340b316f1da80ce4 Mon Sep 17 00:00:00 2001 From: Priyanka Ganesha Date: Fri, 1 Dec 2023 08:55:53 -0800 Subject: [PATCH 10/10] address comments --- MaxText/configs/base.yml | 2 +- MaxText/monitoring_api.py | 39 +++----------------------- MaxText/tests/cloud_monitoring_test.py | 2 ++ 3 files changed, 7 insertions(+), 36 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index cf9ce4259..8979b8dfa 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -178,5 +178,5 @@ use_iota_embed: False #Monitoring parameters - Export in-workload metrics to Cloud monitoring enable_cloud_monitoring: False -cloud_monitoring_dashboard: "https://pantheon.corp.google.com/monitoring/dashboards?project=" +cloud_monitoring_dashboard: "https://pantheon.corp.google.com/monitoring/dashboards?" cloud_zone: "" # zone name for cloud jobs - used for cloud metrics emitting diff --git a/MaxText/monitoring_api.py b/MaxText/monitoring_api.py index e244ac934..91b4f161e 100644 --- a/MaxText/monitoring_api.py +++ b/MaxText/monitoring_api.py @@ -8,29 +8,11 @@ from google.cloud import monitoring_v3 from google.cloud import compute_v1 from google.api import metric_pb2 -import requests import time import os import max_logging -def get_metadata(project_id, zone, instance_id): - """ - Fetches metadata - - Args: - project_id - zone - instance_id - - Returns: - metadata as json - """ - r = requests.get(url="https://compute.googleapis.com/compute/v1/projects/\ - {project_id}/zones/{zone}/instances/{instance_id}") - metadata = r.json() - return metadata - def create_custom_metric(metric_name, description): """ Creates a custom metric @@ -99,13 +81,7 @@ def write_time_series_step(metric_name, monitoring_enabled, pyconfig, step=1): "%d %b %Y %H:%M:%S UTC", time.gmtime(seconds_since_epoch_utc) ) max_logging.log( - "Emitting metric ", - metric_name, - " for step = ", - step, - " at: ", - event_time, - ) + f"Emitting metric {metric_name} for step = {step} at: {event_time}") instance_id = get_instance_id(project_id, zone) @@ -126,18 +102,11 @@ def write_time_series_step(metric_name, monitoring_enabled, pyconfig, step=1): ) ] - client.create_time_series(name=project_name, time_series=[series], metadata=get_metadata(project_id, zone, instance_id)) + client.create_time_series(name=project_name, time_series=[series]) dashboard_link = pyconfig.config.cloud_monitoring_dashboard+project_name max_logging.log( - "Time series added for step", - step, - "and instance_id ", - instance_id, - " and zone ", - zone, - "\nView dashboards or use metrics: ", - dashboard_link, - ) + f"Time series added for step {step} and instance_id {instance_id} and zone {zone}\ + \n View dashboards or use metrics: {dashboard_link}") return [series] def get_time_series_step_data(metric_name): diff --git a/MaxText/tests/cloud_monitoring_test.py b/MaxText/tests/cloud_monitoring_test.py index 582389017..f62e75172 100644 --- a/MaxText/tests/cloud_monitoring_test.py +++ b/MaxText/tests/cloud_monitoring_test.py @@ -30,6 +30,8 @@ def test_write_time_series_step(self): pyconfig.initialize(sys.argv + ['configs/base.yml'], per_device_batch_size=1, run_name='test', mesh_axes = ['data'], logical_axis_rules = [['batch', 'data']], data_sharding = ['data'], + base_output_directory = "gs://max-experiments/", + dataset_path = "gs://maxtext-dataset/", enable_cloud_monitoring=True, cloud_zone='us-central2-b') monitoring_api.create_custom_metric('test_metric', "This is an example metric")