"""
 Copyright 2023 Google LLC

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      https://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

# pylint: skip-file
"""Reads metrics files and asserts over target values."""
import json
from math import isclose
from typing import Sequence

from absl import app
from google.cloud import storage

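# The metrics file is expected to contain one JSON object per line, keyed by
# metric name. A hypothetical example line (values and the perf key are
# illustrative; only "step" and "learning/loss" appear in the code below):
#   {"step": 10, "learning/loss": 2.31, "perf/per_device_tflops_per_sec": 150.0}
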
def compute_avg_metric(metrics_file, target, start_line=10):
  """Reads the metrics file and computes the average of the target value.

  The first start_line lines are skipped as burn-in. If start_line is negative,
  only the last |start_line| lines are averaged, i.e. the effective start line
  is len(lines) + start_line.
  """
  avg = 0
  i = 0
  with open(metrics_file, 'r', encoding='utf8') as file:
    lines = file.readlines()
    if start_line < 0:
      start_line = len(lines) + start_line
    for line in lines:
      # Skip the first start_line lines for burn-in.
      if i >= start_line:
        vals = json.loads(line)
        avg += vals[target]
      i += 1
    avg /= (i - start_line)
  return avg

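# Usage sketch ('metrics.txt' is a hypothetical file name):
#   compute_avg_metric('metrics.txt', 'learning/loss')                 # skips the first 10 lines (burn-in)
#   compute_avg_metric('metrics.txt', 'learning/loss', start_line=-5)  # averages only the last 5 lines
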
def assert_metric_average(metrics_file, threshold, target):
  """Asserts that the average of the target metric meets the threshold."""
  avg_value = compute_avg_metric(metrics_file, target)
  # Checks for acceptable performance by asserting that the average metric
  # (e.g. TFLOP/s) is at least the threshold. The threshold arrives as a
  # string from the command line, hence the float() conversion.
  print(f'avg value of target {target} is {avg_value}')
  assert avg_value >= float(threshold)
  print('assert metric average passed.')


def test_final_loss(metrics_file, target_loss):
  """Asserts that the mean of the last losses is below target_loss."""
  target_loss = float(target_loss)
  use_last_n_data = 10
  # A negative start_line averages only the final use_last_n_data lines.
  avg_final_loss = compute_avg_metric(metrics_file, 'learning/loss', start_line=-1 * use_last_n_data)
  print(f"Mean of last {use_last_n_data} losses is {avg_final_loss}")
  print(f"Target loss is {target_loss}")
  assert avg_final_loss < target_loss
  print('Final loss test passed.')


def test_checkpointing(metrics_file, target, dataset_type):
  """Asserts over loss values from a loaded checkpoint."""
  metrics_file_saved = 'saved_' + metrics_file
  metrics_file_restored = 'restored_' + metrics_file

  with open(metrics_file_saved, 'r', encoding='utf8') as saved, \
       open(metrics_file_restored, 'r', encoding='utf8') as restored:
    saved_loss = json.loads(saved.readlines()[-1])[target]
    restored_loss = json.loads(restored.readlines()[0])[target]
    # Checks that the checkpoint restore was successful by comparing the loss
    # of the last step in the saved run to the loss of the first step in the
    # restored run.
    print("saved loss: ", saved_loss)
    print("restored loss: ", restored_loss)
    if dataset_type == 'c4':
      # The c4 pipeline is not expected to be bit-identical across a restore,
      # so a small relative tolerance is allowed.
      assert isclose(saved_loss, restored_loss, rel_tol=0.1)
    elif dataset_type == 'c4-array_record':
      # The array_record pipeline is deterministic, so require an exact match.
      assert saved_loss == restored_loss
    else:
      raise ValueError(f"Unknown dataset_type {dataset_type}. dataset_type must be c4 or c4-array_record")
  print('checkpointing test passed.')

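# Note on test_checkpointing file naming: with metrics_file='metrics.txt'
# (hypothetical), the saved and restored runs are expected to have written
# 'saved_metrics.txt' and 'restored_metrics.txt' respectively.
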
def test_determinism(metrics_file, target):
  """Asserts over loss values from two runs."""
  run_1 = 'run_1_' + metrics_file
  run_2 = 'run_2_' + metrics_file

  with open(run_1, 'r', encoding='utf8') as run_1_file, \
       open(run_2, 'r', encoding='utf8') as run_2_file:
    run_1_loss = json.loads(run_1_file.readlines()[-1])[target]
    run_2_loss = json.loads(run_2_file.readlines()[-1])[target]
    # Check that the two runs ended with exactly the same loss.
    print(f"Run 1 loss: {run_1_loss}", flush=True)
    print(f"Run 2 loss: {run_2_loss}", flush=True)
    assert run_1_loss == run_2_loss
  print('determinism test passed.')

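# Note on test_determinism: the exact float comparison assumes the two runs
# were launched with identical configuration and a deterministic data
# pipeline; any nondeterminism upstream will make this assert flaky.
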
def test_vocab_creation(target):
  """Asserts that the vocab blob exists at the target gs:// path."""
  bucket_name = target.split("/")[2]
  vocab_path = "/".join(target.split("/")[3:])
  storage_client = storage.Client()
  assert storage.Blob(bucket=storage_client.bucket(bucket_name), name=vocab_path).exists(storage_client)
  print('vocab creation test passed.')

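# Note on test_vocab_creation: a target of 'gs://my-bucket/path/to/vocab'
# (hypothetical) splits into bucket_name 'my-bucket' and vocab_path
# 'path/to/vocab'.
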
def test_start_step(metrics_file, start_step_target):
  """Asserts that the first recorded step matches start_step_target."""
  with open(metrics_file, 'r', encoding='utf8') as metrics:
    start_step = json.loads(metrics.readlines()[0])["step"]
  print(f"Start step is {start_step}, start step target is {start_step_target}")
  assert start_step == float(start_step_target)
  print("Start step test passed.")


def main(argv: Sequence[str]) -> None:
  _, test_scenario, *test_vars = argv

  if test_scenario == 'metrics_average':
    assert_metric_average(*test_vars)
  elif test_scenario == 'checkpoint_save_restore':
    test_checkpointing(*test_vars, dataset_type='c4')
  elif test_scenario == 'grain_checkpoint_save_restore':
    test_checkpointing(*test_vars, dataset_type='c4-array_record')
  elif test_scenario == 'determinism':
    test_determinism(*test_vars)
  elif test_scenario == 'vocab_creation':
    test_vocab_creation(*test_vars)
  elif test_scenario == 'final_loss':
    test_final_loss(*test_vars)
  elif test_scenario == 'test_start_step':
    test_start_step(*test_vars)
  else:
    raise ValueError(f"Unrecognized test_scenario {test_scenario}")

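# Example invocations (the script name, file names and values are
# hypothetical; positional args after the scenario are forwarded verbatim):
#   python3 eval_assert.py metrics_average metrics.txt 100 perf/per_device_tflops_per_sec
#   python3 eval_assert.py final_loss metrics.txt 2.5
#   python3 eval_assert.py checkpoint_save_restore metrics.txt learning/loss
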
if __name__ == "__main__":
  app.run(main)