From e627b5e65a232f7aa6ac4b55c735fad32f87acdd Mon Sep 17 00:00:00 2001
From: ML Metrics Team <ml-metrics-dev@google.com>
Date: Sun, 24 Mar 2024 22:02:07 -0700
Subject: [PATCH] Implement average word count metric.
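
The AvgWordCount aggregate accumulates per-batch word counts and can be
combined across instances. A minimal usage sketch, using only names defined
in this change ('another' stands for a second AvgWordCount instance):

  metric = nlp.AvgWordCount()
  metric.add(['abc', 'd e!'])  # batch mean: 3 words / 2 texts = 1.5
  metric.merge(another)        # fold in the state of another instance
  metric.result()              # running mean over everything added so far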

PiperOrigin-RevId: 618725839
---
 ml_metrics/_src/aggregates/nlp.py      | 104 +++++++++++++++++++++
 ml_metrics/_src/aggregates/nlp_test.py | 124 +++++++++++++++++++++++++
 ml_metrics/_src/metrics/nlp.py         |  34 +++++++
 ml_metrics/_src/metrics/nlp_test.py    |  32 +++++++
 4 files changed, 294 insertions(+)
 create mode 100644 ml_metrics/_src/aggregates/nlp.py
 create mode 100644 ml_metrics/_src/aggregates/nlp_test.py
 create mode 100644 ml_metrics/_src/metrics/nlp.py
 create mode 100644 ml_metrics/_src/metrics/nlp_test.py

diff --git a/ml_metrics/_src/aggregates/nlp.py b/ml_metrics/_src/aggregates/nlp.py
new file mode 100644
index 00000000..dac27901
--- /dev/null
+++ b/ml_metrics/_src/aggregates/nlp.py
@@ -0,0 +1,104 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Individual NLP-based metrics."""
+
+from collections.abc import Sequence
+import dataclasses
+import re
+
+from ml_metrics._src.aggregates import base
+from ml_metrics._src.aggregates import utils
+
+
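+# utils.MeanState accumulates a (total, count) pair; result() returns their
+# ratio, which is the average each metric below reports.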
+MeanState = utils.MeanState
+
+
+@dataclasses.dataclass()
+class AvgCharCount(base.MergeableMetric):
+  """Average character count metric.
+
+  The average character count is the mean number of alphabetical characters in
+  the non-missing texts.
+  """
+
+  _state: MeanState = dataclasses.field(default_factory=MeanState, init=False)
+
+  @property
+  def state(self) -> MeanState:
+    return self._state
+
+  def add(self, texts: Sequence[str | None]) -> float:
+    char_count = 0
+    non_missing_text_count = 0
+    for text in texts:
+      if text is not None:
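+        # Keep only ASCII letters; digits, punctuation, and whitespace do not
+        # count toward the character total.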
+        cleaned_up = re.sub(r'[^a-zA-Z]', '', text)
+        char_count += len(cleaned_up)
+        non_missing_text_count += 1
+
+    batch_state = MeanState(total=char_count, count=non_missing_text_count)
+    self._state += batch_state
+    return batch_state.result()
+
+  def merge(self, other: 'AvgCharCount'):
+    self._state += other.state
+
+  def result(self) -> float:
+    return self._state.result()
+
+
+@dataclasses.dataclass(kw_only=True)
+class AvgCharCountMaker(base.MetricMaker):
+  """Average character count metric maker."""
+
+  def make(self):
+    return AvgCharCount()
+
+
+@dataclasses.dataclass(kw_only=True)
+class AvgWordCount(base.MergeableMetric):
+  """Average word count metric."""
+
+  _state: MeanState = dataclasses.field(default_factory=MeanState, init=False)
+
+  @property
+  def state(self) -> MeanState:
+    return self._state
+
+  def add(self, texts: Sequence[str | None]) -> float:
+    word_count = 0
+    non_missing_text_count = 0
+    for text in texts:
+      if text is not None:
+        # split() without arguments drops empty strings, so repeated or
+        # leading spaces are not counted as words.
+        words = re.sub(r'[^a-zA-Z ]', '', text).split()
+        word_count += len(words)
+        non_missing_text_count += 1
+
+    batch_state = MeanState(total=word_count, count=non_missing_text_count)
+    self._state += batch_state
+    return batch_state.result()
+
+  def merge(self, other: 'AvgWordCount'):
+    self._state += other.state
+
+  def result(self) -> float:
+    return self._state.result()
+
+
+@dataclasses.dataclass(kw_only=True)
+class AvgWordCountMaker(base.MetricMaker):
+  """Average word count metric maker."""
+
+  def make(self):
+    return AvgWordCount()
diff --git a/ml_metrics/_src/aggregates/nlp_test.py b/ml_metrics/_src/aggregates/nlp_test.py
new file mode 100644
index 00000000..777bc45f
--- /dev/null
+++ b/ml_metrics/_src/aggregates/nlp_test.py
@@ -0,0 +1,124 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for nlp."""
+
+from absl.testing import parameterized
+from ml_metrics._src.aggregates import nlp
+from ml_metrics._src.aggregates import utils
+
+from absl.testing import absltest
+
+
+class SimpleMetricsTest(parameterized.TestCase):
+
+  @parameterized.named_parameters([
+      dict(
+          testcase_name='avg_char_count_metric',
+          maker=nlp.AvgCharCountMaker,
+          item_count=5,
+          existing_text_count=2,
+      ),
+      dict(
+          testcase_name='avg_word_count_metric',
+          maker=nlp.AvgWordCountMaker,
+          item_count=3,
+          existing_text_count=2,
+      ),
+  ])
+  def test_compute_metric(
+      self, maker, item_count, existing_text_count
+  ):
+    batch = ['abc', 'd e!']
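+    # 'abc' + 'd e!' -> 5 letters and 3 words across 2 non-missing texts.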
+    avg_item_metric = maker().make()
+    batch_result = avg_item_metric.add(batch)
+
+    self.assertAlmostEqual(
+        batch_result, float(item_count / existing_text_count)
+    )
+    expected_state = utils.MeanState(item_count, existing_text_count)
+    self.assertEqual(avg_item_metric.state, expected_state)
+    self.assertAlmostEqual(
+        avg_item_metric.result(), float(item_count / existing_text_count)
+    )
+
+  @parameterized.named_parameters([
+      dict(
+          testcase_name='avg_char_count_metric',
+          maker=nlp.AvgCharCountMaker,
+          batch_result=float(2/1),
+          item_count=7,
+          existing_text_count=3,
+      ),
+      dict(
+          testcase_name='avg_word_count_metric',
+          maker=nlp.AvgWordCountMaker,
+          batch_result=float(1/1),
+          item_count=4,
+          existing_text_count=3,
+      ),
+  ])
+  def test_metric_add(
+      self, maker, batch_result, item_count, existing_text_count
+  ):
+    avg_item_metric = maker().make()
+
+    batch_0 = ['abc', 'd e!']
+    avg_item_metric.add(batch_0)
+
+    batch_1 = ['fi']
+    batch_1_result = avg_item_metric.add(batch_1)
+    self.assertAlmostEqual(batch_1_result, batch_result)
+
+    expected_updated_state = utils.MeanState(item_count, existing_text_count)
+    self.assertEqual(avg_item_metric.state, expected_updated_state)
+    self.assertAlmostEqual(
+        avg_item_metric.result(), float(item_count / existing_text_count)
+    )
+
+  @parameterized.named_parameters([
+      dict(
+          testcase_name='avg_char_count_metric',
+          maker=nlp.AvgCharCountMaker,
+          item_count=7,
+          existing_text_count=3,
+      ),
+      dict(
+          testcase_name='avg_word_count_metric',
+          maker=nlp.AvgWordCountMaker,
+          item_count=3,
+          existing_text_count=3,
+      ),
+  ])
+  def test_metric_merge(
+      self, maker, item_count, existing_text_count
+  ):
+    batch_0 = ['abc', 'de']
+    avg_item_metric_0 = maker().make()
+    avg_item_metric_0.add(batch_0)
+
+    batch_1 = ['fi']
+    avg_item_metric_1 = maker().make()
+    avg_item_metric_1.add(batch_1)
+
+    avg_item_metric_0.merge(avg_item_metric_1)
+
+    expected_state = utils.MeanState(item_count, existing_text_count)
+    self.assertEqual(avg_item_metric_0.state, expected_state)
+    self.assertAlmostEqual(
+        avg_item_metric_0.result(), float(item_count / existing_text_count)
+    )
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/ml_metrics/_src/metrics/nlp.py b/ml_metrics/_src/metrics/nlp.py
new file mode 100644
index 00000000..9e3835bd
--- /dev/null
+++ b/ml_metrics/_src/metrics/nlp.py
@@ -0,0 +1,34 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Individual NLP based metrics."""
+
+from collections.abc import Sequence
+
+from ml_metrics import aggregates
+from ml_metrics._src.aggregates import nlp
+
+
+def avg_char_count(texts: Sequence[str | None]) -> float:
+  """Compute average character count metric.
+
+  The average character count is the mean number of alphabetical characters in
+  the non-missing texts.
+
+  Args:
+    texts: Sequence of texts.
+
+  Returns:
+    Metric value.
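+
+  Example (mirrors the values in nlp_test.py):
+    avg_char_count(['abc', 'd e!', None])  # 5 letters / 2 texts -> 2.5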
+  """
+  return aggregates.MergeableMetricAggFn(metric=nlp.AvgCharCountMaker())(texts)
diff --git a/ml_metrics/_src/metrics/nlp_test.py b/ml_metrics/_src/metrics/nlp_test.py
new file mode 100644
index 00000000..b3ffb88c
--- /dev/null
+++ b/ml_metrics/_src/metrics/nlp_test.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for nlp."""
+
+from ml_metrics._src.metrics import nlp
+from absl.testing import absltest
+
+
+class NlpTest(absltest.TestCase):
+
+  def test_avg_char_count(self):
+    texts = ['abc', 'd e!', None]
+    result = nlp.avg_char_count(texts)
+    self.assertAlmostEqual(float(5/2), result)
+
+  def test_avg_char_count_empty(self):
+    result = nlp.avg_char_count([])
+    self.assertAlmostEqual(float(0), result)
+
+
+if __name__ == '__main__':
+  absltest.main()