-
Notifications
You must be signed in to change notification settings - Fork 310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Custom metrics emitting #170
base: main
Are you sure you want to change the base?
Changes from all commits
1f99e87
ec2a177
6326514
494ebf6
5ee4fe3
54caff6
c585140
3e374e8
a5767b4
48c5dc2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
# pylint: disable=unused-argument, no-name-in-module | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It really feels like this file is in the wrong project. I wonder if Branden/Surbhi could recommend a better home for it? |
||
""" | ||
Cloud Monitoring API v3 Prototype | ||
""" | ||
|
||
import subprocess | ||
import sys | ||
from google.cloud import monitoring_v3 | ||
from google.cloud import compute_v1 | ||
from google.api import metric_pb2 | ||
import time | ||
import os | ||
|
||
import max_logging | ||
|
||
def create_custom_metric(metric_name, description):
  """
  Creates a custom metric descriptor in the currently configured project.

  The descriptor is registered as a GAUGE metric carrying DOUBLE values,
  under the "custom.googleapis.com/" type prefix.

  Args:
    metric_name: name appended to "custom.googleapis.com/" to form the type
    description: description stored on the metric descriptor

  Returns:
    Response from create request
  """
  project_name = f"projects/{get_project()}"

  # Go through the module's client factory so every entry point obtains the
  # metrics client the same way (write_time_series_step already does).
  client = get_metrics_service_client()

  descriptor = metric_pb2.MetricDescriptor()
  descriptor.type = "custom.googleapis.com/" + metric_name
  descriptor.metric_kind = metric_pb2.MetricDescriptor.MetricKind.GAUGE
  descriptor.value_type = metric_pb2.MetricDescriptor.ValueType.DOUBLE
  descriptor.description = description

  request = monitoring_v3.CreateMetricDescriptorRequest(
      name=project_name,
      metric_descriptor=descriptor,
  )

  return client.create_metric_descriptor(request=request)
|
||
|
||
def write_time_series_step(metric_name, monitoring_enabled, pyconfig, step=1):
  """
  Writes a time series object for a specified custom metric

  A single point whose value is `step` is written against the
  "gce_instance" resource identified by this host's instance id and zone.

  Args:
    metric_name: name appended to "custom.googleapis.com/" to form the type
    monitoring_enabled: when falsy, nothing is emitted and [] is returned
    pyconfig: object exposing config.cloud_zone and
      config.cloud_monitoring_dashboard
    step: value recorded as the point and as the "step_num" label

  Returns:
    A one-element list holding the TimeSeries that was written, or an empty
    list when monitoring is disabled.
  """
  # Check the switch first: a disabled run must not read the config, shell
  # out to gcloud, or build API clients.
  if not monitoring_enabled:
    return []

  zone = pyconfig.config.cloud_zone
  project_id = get_project()
  project_name = f"projects/{project_id}"

  client = get_metrics_service_client()

  # Split the current time into whole seconds plus the nanosecond remainder.
  seconds_since_epoch_utc = time.time()
  whole_seconds = int(seconds_since_epoch_utc)
  nanos_since_epoch_utc = int((seconds_since_epoch_utc - whole_seconds) * 10**9)
  interval = monitoring_v3.types.TimeInterval(
      {
          "end_time": {
              "seconds": whole_seconds,
              "nanos": nanos_since_epoch_utc,
          }
      }
  )

  event_time = time.strftime(
      "%d %b %Y %H:%M:%S UTC", time.gmtime(seconds_since_epoch_utc)
  )
  max_logging.log(
      f"Emitting metric {metric_name} for step = {step} at: {event_time}")

  instance_id = get_instance_id(project_id, zone)

  series = monitoring_v3.types.TimeSeries()
  series.metric.type = "custom.googleapis.com/" + metric_name
  series.resource.type = "gce_instance"
  series.resource.labels["instance_id"] = str(instance_id)
  series.resource.labels["zone"] = zone
  series.metric.labels["step_num"] = str(step)
  series.metric.labels["worker"] = os.uname().nodename
  series.metric.labels["event_time"] = event_time
  series.points = [
      monitoring_v3.types.Point(
          interval=interval,
          value=monitoring_v3.types.TypedValue(double_value=step),
      )
  ]

  client.create_time_series(name=project_name, time_series=[series])
  dashboard_link = pyconfig.config.cloud_monitoring_dashboard + project_name
  max_logging.log(
      f"Time series added for step {step} and instance_id {instance_id} and zone {zone}"
      f"\n View dashboards or use metrics: {dashboard_link}")
  return [series]
|
||
def get_time_series_step_data(metric_name):
  """
  Queries the time series recorded by this host for a custom metric

  The query is restricted to points whose "worker" label equals this
  machine's node name.

  Args:
    metric_name: name appended to "custom.googleapis.com/" to form the type

  Returns:
    The time_series_data attribute of the query result
  """
  query_template = """
fetch gce_instance
| metric 'custom.googleapis.com/{metric_name}'
| filter (metric.worker == '{worker_id}')
| every 1m
| within -1d, 1d # one day, starting 1 day ago
"""

  parent = f"projects/{get_project()}"
  worker = os.uname().nodename
  query_text = query_template.format(metric_name=metric_name, worker_id=worker)

  query_client = get_query_service_client()
  query_request = monitoring_v3.QueryTimeSeriesRequest(
      {"name": parent, "query": query_text}
  )
  return query_client.query_time_series(query_request).time_series_data
|
||
|
||
def get_instance_id(project_id, zone):
  """
  Looks up the Compute Engine id of the machine running this code

  The instance is located by this host's node name within the given
  project and zone.

  Args:
    project_id: project that owns the instance
    zone: zone the instance runs in

  Returns:
    The id attribute of the fetched instance
  """
  node_name = os.uname().nodename
  compute_client = get_compute_instances_client()
  return compute_client.get(
      project=project_id, zone=zone, instance=node_name
  ).id
|
||
def get_project():
  """
  Reads the active project id from the gcloud configuration

  Runs "gcloud config get project" and uses the last line of its standard
  output. The process exits with an explanatory message when that line is
  empty.
  """
  gcloud_result = subprocess.run(
      ["gcloud", "config", "get", "project"],
      check=True,
      capture_output=True,
  )
  stdout_lines = gcloud_result.stdout.decode().strip().split('\n')
  # Only the final line is used as the project id.
  project_id = stdout_lines[-1] if stdout_lines else ''
  if project_id == '':
    sys.exit("You must specify the project in the PROJECT flag or set it with 'gcloud config set project <project>'")
  return project_id
|
||
def get_compute_instances_client():
  """
  Builds a new Compute Engine instances client

  Returns:
    A freshly constructed compute_v1.InstancesClient
  """
  instances_client = compute_v1.InstancesClient()
  return instances_client
|
||
def get_metrics_service_client():
  """
  Builds a new Cloud Monitoring metric service client

  Returns:
    A freshly constructed monitoring_v3.MetricServiceClient
  """
  metrics_client = monitoring_v3.MetricServiceClient()
  return metrics_client
|
||
def get_query_service_client():
  """
  Builds a new Cloud Monitoring query service client

  Returns:
    A freshly constructed monitoring_v3.QueryServiceClient
  """
  query_client = monitoring_v3.QueryServiceClient()
  return query_client
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
""" | ||
Copyright 2023 Google LLC | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
https://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
""" Tests for Cloud Monitoring API """ | ||
import sys | ||
import jax | ||
import unittest | ||
|
||
import monitoring_api | ||
import pyconfig | ||
|
||
jax.config.update('jax_platform_name', 'cpu') | ||
|
||
class CloudMonitoringTests(unittest.TestCase):
  """Test for writing time series step using monitoring_api.py"""

  def test_write_time_series_step(self):
    """Creates a metric, writes one step to it, then queries it back."""
    config_overrides = {
        'per_device_batch_size': 1,
        'run_name': 'test',
        'mesh_axes': ['data'],
        'logical_axis_rules': [['batch', 'data']],
        'data_sharding': ['data'],
        'base_output_directory': "gs://max-experiments/",
        'dataset_path': "gs://maxtext-dataset/",
        'enable_cloud_monitoring': True,
        'cloud_zone': 'us-central2-b',
    }
    pyconfig.initialize(sys.argv + ['configs/base.yml'], **config_overrides)

    monitoring_api.create_custom_metric('test_metric', "This is an example metric")
    written = monitoring_api.write_time_series_step('test_metric', True, pyconfig, 1)
    queried = monitoring_api.get_time_series_step_data('test_metric')
    # NOTE(review): `written` is the list returned by the write call while
    # `queried` is the query result's time_series_data; confirm the two are
    # comparable for equality.
    self.assertEqual(written, queried)
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,6 +56,7 @@ | |
from cloud_tpu_diagnostics.configuration import stack_trace_configuration | ||
|
||
import max_logging | ||
import monitoring_api | ||
rwitten marked this conversation as resolved.
Show resolved
Hide resolved
|
||
cc.initialize_cache(os.path.expanduser("~/jax_cache")) | ||
|
||
# https://arxiv.org/pdf/2204.02311.pdf Appendix B | ||
|
@@ -211,13 +212,30 @@ def train_loop(config, state=None): | |
Returns: | ||
|
||
""" | ||
|
||
monitoring_enabled = config.enable_cloud_monitoring | ||
|
||
if monitoring_enabled: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's move the registrations and all this logic into a standalone function, e.g. register_all_train_metrics. |
||
max_utils.register_train_metrics('checkpointint_init_start', "Checkpointing Initialization Start") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Typo in "checkpointing"! |
||
max_utils.register_train_metrics('checkpointing_init_end', "Checkpointing Initialization End") | ||
max_utils.register_train_metrics('checkpoint_test_run_start', "Checkpointing Test Run Start") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What do test_run_start and test_run_end mean in this context? |
||
max_utils.register_train_metrics('checkpoint_test_run_end', "Checkpointing Test Run End") | ||
|
||
monitoring_api.write_time_series_step('checkpoint_test_run_start', monitoring_enabled, pyconfig, 0) | ||
|
||
writer = SummaryWriter(config.tensorboard_dir) | ||
|
||
monitoring_api.write_time_series_step('checkpointing_init_start', monitoring_enabled, pyconfig, 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this step 1? |
||
|
||
checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( | ||
config.checkpoint_dir, | ||
config.enable_checkpointing, | ||
config.async_checkpointing, | ||
config.save_period, | ||
) | ||
|
||
monitoring_api.write_time_series_step('checkpointing_init_end', monitoring_enabled, pyconfig, 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this step 1? |
||
|
||
# Initial PRNG Keys | ||
init_rng, nextrng = random.split(random.PRNGKey(config.init_weights_seed), 2) | ||
|
||
|
@@ -300,6 +318,7 @@ def train_loop(config, state=None): | |
if step == 0: | ||
max_utils.activate_profiler(config) | ||
|
||
monitoring_api.write_time_series_step('checkpoint_test_run_end', monitoring_enabled, pyconfig, config.steps) | ||
max_utils.deactivate_profiler(config) | ||
writer.close() | ||
return state | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like it should be in monitoring_api.py (notice that monitoring_api.py isn't ML, MaxText or Max specific AFAICT)