From e627b5e65a232f7aa6ac4b55c735fad32f87acdd Mon Sep 17 00:00:00 2001 From: ML Metrics Team Date: Sun, 24 Mar 2024 22:02:07 -0700 Subject: [PATCH] Implement average word count metric. PiperOrigin-RevId: 618725839 --- ml_metrics/_src/aggregates/nlp.py | 104 +++++++++++++++++++++ ml_metrics/_src/aggregates/nlp_test.py | 124 +++++++++++++++++++++++++ ml_metrics/_src/metrics/nlp.py | 34 +++++++ ml_metrics/_src/metrics/nlp_test.py | 32 +++++++ 4 files changed, 294 insertions(+) create mode 100644 ml_metrics/_src/aggregates/nlp.py create mode 100644 ml_metrics/_src/aggregates/nlp_test.py create mode 100644 ml_metrics/_src/metrics/nlp.py create mode 100644 ml_metrics/_src/metrics/nlp_test.py diff --git a/ml_metrics/_src/aggregates/nlp.py b/ml_metrics/_src/aggregates/nlp.py new file mode 100644 index 00000000..dac27901 --- /dev/null +++ b/ml_metrics/_src/aggregates/nlp.py @@ -0,0 +1,104 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Individual NLP-based metrics.""" + +from collections.abc import Sequence +import dataclasses +import re + +from ml_metrics._src.aggregates import base +from ml_metrics._src.aggregates import utils + + +MeanState = utils.MeanState + + +@dataclasses.dataclass() +class AvgCharCount(base.MergeableMetric): + """Average character count metric. + + The average character count is the mean number of alphabetical characters in + the non-missing texts. + """ + + _state: MeanState = dataclasses.field(default_factory=MeanState, init=False) + + @property + def state(self) -> MeanState: + return self._state + + def add(self, texts: Sequence[str|None]) -> float: + char_count = 0 + non_missing_text_count = 0 + for text in texts: + if text is not None: + cleaned_up = re.sub(r'[^a-zA-Z]', '', text) + char_count += len(cleaned_up) + non_missing_text_count += 1 + + batch_state = MeanState(total=char_count, count=non_missing_text_count) + self._state += batch_state + return batch_state.result() + + def merge(self, other: 'AvgCharCount'): + self._state += other.state + + def result(self) -> float: + return self._state.result() + + +@dataclasses.dataclass(kw_only=True) +class AvgCharCountMaker(base.MetricMaker): + """Average character count metric maker.""" + + def make(self): + return AvgCharCount() + + +@dataclasses.dataclass(kw_only=True) +class AvgWordCount(base.MergeableMetric): + """Average word count metric.""" + + _state: MeanState = dataclasses.field(default_factory=MeanState, init=False) + + @property + def state(self) -> MeanState: + return self._state + + def add(self, texts: Sequence[str|None]) -> float: + word_count = 0 + non_missing_text_count = 0 + for text in texts: + if text is not None: + words = re.sub(r'[^a-zA-Z ]', '', text).split(' ') + word_count += len(words) + non_missing_text_count += 1 + + batch_state = MeanState(total=word_count, count=non_missing_text_count) + self._state += batch_state + return batch_state.result() + + def merge(self, other: 'AvgWordCount'): + self._state += other.state + + def result(self) -> float: + return self._state.result() + + +@dataclasses.dataclass(kw_only=True) +class AvgWordCountMaker(base.MetricMaker): + """Average word count metric maker.""" + + def make(self): + return AvgWordCount() diff --git a/ml_metrics/_src/aggregates/nlp_test.py b/ml_metrics/_src/aggregates/nlp_test.py new file mode 100644 index 00000000..777bc45f --- /dev/null +++ b/ml_metrics/_src/aggregates/nlp_test.py @@ -0,0 +1,124 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for nlp.""" + +from absl.testing import parameterized +from ml_metrics._src.aggregates import nlp +from ml_metrics._src.aggregates import utils + +from absl.testing import absltest + + +class SimpleMetricsTest(parameterized.TestCase): + + @parameterized.named_parameters([ + dict( + testcase_name='avg_char_count_metric', + maker=nlp.AvgCharCountMaker, + item_count=5, + existing_text_count=2, + ), + dict( + testcase_name='avg_word_count_metric', + maker=nlp.AvgWordCountMaker, + item_count=3, + existing_text_count=2, + ), + ]) + def test_compute_metric( + self, maker, item_count, existing_text_count + ): + batch = ['abc', 'd e!'] + avg_item_metric = maker().make() + batch_result = avg_item_metric.add(batch) + + self.assertAlmostEqual( + batch_result, float(item_count / existing_text_count) + ) + expected_state = utils.MeanState(item_count, existing_text_count) + self.assertEqual(avg_item_metric.state, expected_state) + self.assertAlmostEqual( + avg_item_metric.result(), float(item_count / existing_text_count) + ) + + @parameterized.named_parameters([ + dict( + testcase_name='avg_char_count_metric', + maker=nlp.AvgCharCountMaker, + batch_result=float(2/1), + item_count=7, + existing_text_count=3, + ), + dict( + testcase_name='avg_word_count_metric', + maker=nlp.AvgWordCountMaker, + batch_result=float(1/1), + item_count=4, + existing_text_count=3, + ), + ]) + def test_avg_char_count_metric_add( + self, maker, batch_result, item_count, existing_text_count + ): + avg_item_metric = maker().make() + + batch_0 = ['abc', 'd e!'] + avg_item_metric.add(batch_0) + + batch_1 = ['fi'] + batch_1_result = avg_item_metric.add(batch_1) + self.assertAlmostEqual(batch_1_result, batch_result) + + expected_updated_state = utils.MeanState(item_count, existing_text_count) + self.assertEqual(avg_item_metric.state, expected_updated_state) + self.assertAlmostEqual( + avg_item_metric.result(), float(item_count / existing_text_count) + ) + + @parameterized.named_parameters([ + dict( + testcase_name='avg_char_count_metric', + maker=nlp.AvgCharCountMaker, + item_count=7, + existing_text_count=3, + ), + dict( + testcase_name='avg_word_count_metric', + maker=nlp.AvgWordCountMaker, + item_count=3, + existing_text_count=3, + ), + ]) + def test_avg_char_count_metric_merge( + self, maker, item_count, existing_text_count + ): + batch_0 = ['abc', 'de'] + avg_item_metric_0 = maker().make() + avg_item_metric_0.add(batch_0) + + batch_1 = ['fi'] + avg_item_metric_1 = maker().make() + avg_item_metric_1.add(batch_1) + + avg_item_metric_0.merge(avg_item_metric_1) + + expected_state = utils.MeanState(item_count, existing_text_count) + self.assertEqual(avg_item_metric_0.state, expected_state) + self.assertAlmostEqual( + avg_item_metric_0.result(), float(item_count / existing_text_count) + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/ml_metrics/_src/metrics/nlp.py b/ml_metrics/_src/metrics/nlp.py new file mode 100644 index 00000000..9e3835bd --- /dev/null +++ b/ml_metrics/_src/metrics/nlp.py @@ -0,0 +1,34 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Individual NLP based metrics.""" + +from collections.abc import Sequence + +from ml_metrics import aggregates +from ml_metrics._src.aggregates import nlp + + +def avg_char_count(texts: Sequence[str|None]) -> float: + """Compute average character count metric. + + The average character count is the mean number of alphabetical characters in + the non-missing texts. + + Args: + texts: Sequence of texts. + + Returns: + Metric value. + """ + return aggregates.MergeableMetricAggFn(metric=nlp.AvgCharCountMaker())(texts) diff --git a/ml_metrics/_src/metrics/nlp_test.py b/ml_metrics/_src/metrics/nlp_test.py new file mode 100644 index 00000000..b3ffb88c --- /dev/null +++ b/ml_metrics/_src/metrics/nlp_test.py @@ -0,0 +1,32 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for nlp.""" + +from ml_metrics._src.metrics import nlp +from absl.testing import absltest + + +class NlpTest(absltest.TestCase): + + def test_avg_char_count(self): + texts = ['abc', 'd e!', None] + result = nlp.avg_char_count(texts) + self.assertAlmostEqual(float(5/2), result) + + def test_avg_char_count_empty(self): + result = nlp.avg_char_count([]) + self.assertAlmostEqual(float(0), result) + +if __name__ == '__main__': + absltest.main()