Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom metrics emitting #170

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,8 @@ stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds

# Use iota operator in Embed
use_iota_embed: False

# Monitoring parameters - export in-workload metrics to Cloud Monitoring
enable_cloud_monitoring: False
# Base URL for Cloud Monitoring dashboards; the project name is appended at runtime.
cloud_monitoring_dashboard: "https://pantheon.corp.google.com/monitoring/dashboards?"
cloud_zone: "" # GCE zone for cloud jobs - required when enable_cloud_monitoring is True
4 changes: 4 additions & 0 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import functools

import max_logging
import monitoring_api

import numpy as np
import jax
Expand Down Expand Up @@ -227,6 +228,9 @@ def setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager):
state = unbox_logicallypartioned_trainstate(state)
return state, state_mesh_annotations

def register_train_metrics(metric_name, metric_description):
  """Register a custom Cloud Monitoring metric used during training.

  Thin wrapper over monitoring_api.create_custom_metric so train-side code
  only depends on max_utils.

  Args:
    metric_name: metric name, without the "custom.googleapis.com/" prefix.
    metric_description: human-readable description of the metric.
  """
  monitoring_api.create_custom_metric(metric_name, metric_description)


# Learning Rate Schedule
# -----------------------------------------------------------------------------
Expand Down
182 changes: 182 additions & 0 deletions MaxText/monitoring_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# pylint: disable=unused-argument, no-name-in-module
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It really feels like this file is in the wrong project. I wonder if Branden/Surbhi could recommend a better home for it?

"""
Cloud Monitoring API v3 Prototype
"""

import subprocess
import sys
from google.cloud import monitoring_v3
from google.cloud import compute_v1
from google.api import metric_pb2
import time
import os

import max_logging

def create_custom_metric(metric_name, description):
  """Create a custom metric descriptor in Cloud Monitoring.

  Registers the metric as "custom.googleapis.com/<metric_name>" in the
  currently configured gcloud project (see get_project()).

  Args:
    metric_name: metric name, without the "custom.googleapis.com/" prefix.
    description: human-readable description of the metric.

  Returns:
    The MetricDescriptor response from the create request.
  """
  project_id = get_project()
  project_name = f"projects/{project_id}"

  client = monitoring_v3.MetricServiceClient()

  descriptor = metric_pb2.MetricDescriptor()
  descriptor.type = "custom.googleapis.com/" + metric_name
  # GAUGE of DOUBLE values: each written point is an instantaneous measurement.
  descriptor.metric_kind = metric_pb2.MetricDescriptor.MetricKind.GAUGE
  descriptor.value_type = metric_pb2.MetricDescriptor.ValueType.DOUBLE
  descriptor.description = description

  request = monitoring_v3.CreateMetricDescriptorRequest(
      name=project_name,
      metric_descriptor=descriptor
  )

  response = client.create_metric_descriptor(request=request)

  return response


def write_time_series_step(metric_name, monitoring_enabled, pyconfig, step=1):
  """Write one GAUGE data point for a custom metric to Cloud Monitoring.

  The point's double value is `step`, and it is labeled with the step number,
  worker hostname and event time, attached to this node's GCE instance.

  Args:
    metric_name: metric name, without the "custom.googleapis.com/" prefix.
    monitoring_enabled: when False this function is a no-op.
    pyconfig: config module exposing cloud_zone and cloud_monitoring_dashboard
      under pyconfig.config.
    step: training step number, recorded as the point's value (default 1).

  Returns:
    A one-element list holding the written TimeSeries, or an empty list when
    monitoring is disabled.
  """
  # Bail out first: get_project() shells out to the gcloud CLI, and the zone
  # lookup reads config — both are wasted work when monitoring is disabled.
  # (Previously these ran before the enabled check.)
  if not monitoring_enabled:
    return []

  zone = pyconfig.config.cloud_zone
  project_id = get_project()

  client = get_metrics_service_client()
  project_name = f"projects/{project_id}"

  seconds_since_epoch_utc = time.time()
  nanos_since_epoch_utc = int(
      (seconds_since_epoch_utc - int(seconds_since_epoch_utc)) * 10**9
  )
  # GAUGE points only need an end time.
  interval = monitoring_v3.types.TimeInterval(
      {
          "end_time": {
              "seconds": int(seconds_since_epoch_utc),
              "nanos": nanos_since_epoch_utc,
          }
      }
  )

  event_time = time.strftime(
      "%d %b %Y %H:%M:%S UTC", time.gmtime(seconds_since_epoch_utc)
  )
  max_logging.log(
      f"Emitting metric {metric_name} for step = {step} at: {event_time}")

  instance_id = get_instance_id(project_id, zone)

  series = monitoring_v3.types.TimeSeries()
  series.metric.type = "custom.googleapis.com/" + metric_name
  series.resource.type = "gce_instance"
  series.resource.labels["instance_id"] = str(instance_id)
  series.resource.labels["zone"] = zone
  series.metric.labels["step_num"] = str(step)
  series.metric.labels["worker"] = os.uname().nodename
  series.metric.labels["event_time"] = event_time
  series.points = [
      monitoring_v3.types.Point(
          interval=interval,
          value=monitoring_v3.types.TypedValue(
              double_value=step
          ),
      )
  ]

  client.create_time_series(name=project_name, time_series=[series])
  dashboard_link = pyconfig.config.cloud_monitoring_dashboard+project_name
  max_logging.log(
      f"Time series added for step {step} and instance_id {instance_id} and zone {zone}\
  \n View dashboards or use metrics: {dashboard_link}")
  return [series]

def get_time_series_step_data(metric_name):
  """Query Cloud Monitoring for this worker's points of a custom metric.

  Args:
    metric_name: metric name, without the "custom.googleapis.com/" prefix.

  Returns:
    The time_series_data of the MQL query response (last day of points,
    restricted to this worker's hostname).
  """
  project_name = f"projects/{get_project()}"
  worker = os.uname().nodename

  # MQL template; filled in below with the metric name and worker hostname.
  mql = """
  fetch gce_instance
  | metric 'custom.googleapis.com/{metric_name}'
  | filter (metric.worker == '{worker_id}')
  | every 1m
  | within -1d, 1d # one day, starting 1 day ago
  """

  request = monitoring_v3.QueryTimeSeriesRequest({
      "name": project_name,
      "query": mql.format(metric_name=metric_name, worker_id=worker),
  })
  response = get_query_service_client().query_time_series(request)
  return response.time_series_data


def get_instance_id(project_id, zone):
  """Look up the numeric GCE instance id of this node.

  Args:
    project_id: GCP project id.
    zone: GCE zone the instance lives in.

  Returns:
    The instance's numeric id, resolved by hostname via the Compute API.
  """
  hostname = os.uname().nodename
  compute_client = get_compute_instances_client()
  instance = compute_client.get(project=project_id, zone=zone, instance=hostname)
  return instance.id

def get_project():
  """Return the id of the gcloud project currently in use.

  Shells out to `gcloud config get project` and takes the last line of its
  stdout. Exits the process with a message if no project is configured.
  """
  completed = subprocess.run(
      ["gcloud", "config", "get", "project"], check=True, capture_output=True)
  lines = completed.stdout.decode().strip().split('\n')
  if not lines or lines[-1] == '':
    sys.exit("You must specify the project in the PROJECT flag or set it with 'gcloud config set project <project>'")
  return lines[-1]

def get_compute_instances_client():
  """Build and return a Compute Engine InstancesClient."""
  client = compute_v1.InstancesClient()
  return client

def get_metrics_service_client():
  """Build and return a Cloud Monitoring MetricServiceClient."""
  client = monitoring_v3.MetricServiceClient()
  return client

def get_query_service_client():
  """Build and return a Cloud Monitoring QueryServiceClient."""
  client = monitoring_v3.QueryServiceClient()
  return client
2 changes: 2 additions & 0 deletions MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def user_init(raw_keys):
raw_keys["run_name"] = os.environ.get("JOBSET_NAME") #using XPK default
run_name = raw_keys["run_name"]
assert run_name, "Erroring out, need a real run_name"
assert ((raw_keys['cloud_zone']!="" or not raw_keys['enable_cloud_monitoring'])),\
"You must provide cloud_zone if cloud monitoring is enabled"
base_output_directory = raw_keys["base_output_directory"]
validate_gcs_bucket_name(base_output_directory, "base_output_directory")
dataset_path = raw_keys["dataset_path"]
Expand Down
45 changes: 45 additions & 0 deletions MaxText/tests/cloud_monitoring_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

""" Tests for Cloud Monitoring API """
import sys
import jax
import unittest

import monitoring_api
import pyconfig

jax.config.update('jax_platform_name', 'cpu')

class CloudMonitoringTests(unittest.TestCase):
  """Round-trip test: write a time series step, then query it back."""

  def test_write_time_series_step(self):
    # Minimal config with monitoring enabled; cloud_zone must be set because
    # pyconfig asserts it whenever enable_cloud_monitoring is True.
    pyconfig.initialize(sys.argv + ['configs/base.yml'], per_device_batch_size=1, run_name='test', mesh_axes = ['data'],
                        logical_axis_rules = [['batch', 'data']],
                        data_sharding = ['data'],
                        base_output_directory = "gs://max-experiments/",
                        dataset_path = "gs://maxtext-dataset/",
                        enable_cloud_monitoring=True,
                        cloud_zone='us-central2-b')
    monitoring_api.create_custom_metric('test_metric', "This is an example metric")
    create_time_series_result = monitoring_api.write_time_series_step('test_metric', True, pyconfig, 1)
    query_time_series_result = monitoring_api.get_time_series_step_data('test_metric')
    # NOTE(review): write_time_series_step returns a list of TimeSeries while
    # get_time_series_step_data returns the query's time_series_data — confirm
    # these actually compare equal rather than always failing.
    self.assertEqual(create_time_series_result, query_time_series_result)


if __name__ == '__main__':
unittest.main()

19 changes: 19 additions & 0 deletions MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from cloud_tpu_diagnostics.configuration import stack_trace_configuration

import max_logging
import monitoring_api
rwitten marked this conversation as resolved.
Show resolved Hide resolved
cc.initialize_cache(os.path.expanduser("~/jax_cache"))

# https://arxiv.org/pdf/2204.02311.pdf Appendix B
Expand Down Expand Up @@ -211,13 +212,30 @@ def train_loop(config, state=None):
Returns:

"""

monitoring_enabled = config.enable_cloud_monitoring

if monitoring_enabled:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's move registrations and all this logic into a standalone function. register_all_train_metrics

max_utils.register_train_metrics('checkpointing_init_start', "Checkpointing Initialization Start")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in checkpointing!

max_utils.register_train_metrics('checkpointing_init_end', "Checkpointing Initialization End")
max_utils.register_train_metrics('checkpoint_test_run_start', "Checkpointing Test Run Start")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do test_run_start and test_run_end mean in this context?

max_utils.register_train_metrics('checkpoint_test_run_end', "Checkpointing Test Run End")

monitoring_api.write_time_series_step('checkpoint_test_run_start', monitoring_enabled, pyconfig, 0)

writer = SummaryWriter(config.tensorboard_dir)

monitoring_api.write_time_series_step('checkpointing_init_start', monitoring_enabled, pyconfig, 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this step 1?


checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
config.checkpoint_dir,
config.enable_checkpointing,
config.async_checkpointing,
config.save_period,
)

monitoring_api.write_time_series_step('checkpointing_init_end', monitoring_enabled, pyconfig, 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this step 1?


# Initial PRNG Keys
init_rng, nextrng = random.split(random.PRNGKey(config.init_weights_seed), 2)

Expand Down Expand Up @@ -300,6 +318,7 @@ def train_loop(config, state=None):
if step == 0:
max_utils.activate_profiler(config)

monitoring_api.write_time_series_step('checkpoint_test_run_end', monitoring_enabled, pyconfig, config.steps)
max_utils.deactivate_profiler(config)
writer.close()
return state
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ absl-py
argparse
cloud-tpu-diagnostics
datetime
google-cloud-compute
google-cloud-monitoring
google-cloud-storage
flax
ml-collections
Expand Down
Loading