From 600a98d5becf74de5f9546383aace2e86b0b64f3 Mon Sep 17 00:00:00 2001 From: William Baker Date: Thu, 25 Jan 2024 16:13:19 +0000 Subject: [PATCH 1/3] no need to copy objects when annotated --- tracr/rasp/rasp.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tracr/rasp/rasp.py b/tracr/rasp/rasp.py index a074019..43afb23 100644 --- a/tracr/rasp/rasp.py +++ b/tracr/rasp/rasp.py @@ -157,10 +157,9 @@ def annotated(self: RASPExprT, **annotations) -> RASPExprT: def annotate(expr: RASPExprT, **annotations) -> RASPExprT: """Creates a new expr with added annotations.""" - new = expr.copy() # Note that new annotations will overwrite existing ones with matching keys. - new.annotations = {**expr.annotations, **annotations} - return new + expr.annotations = {**expr.annotations, **annotations} + return expr ### S-Ops. From 0593db8064694fb68b14d031f2afdc32545b8b24 Mon Sep 17 00:00:00 2001 From: William Baker Date: Mon, 29 Jan 2024 16:13:11 +0000 Subject: [PATCH 2/3] updated doc string and removed dependencies --- tracr/rasp/rasp.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tracr/rasp/rasp.py b/tracr/rasp/rasp.py index 43afb23..d943d8a 100644 --- a/tracr/rasp/rasp.py +++ b/tracr/rasp/rasp.py @@ -34,7 +34,6 @@ import abc import collections.abc -import copy import enum import functools import itertools @@ -138,10 +137,6 @@ def unique_id(self): """A unique id for every expression instance.""" return next(self._ids) - def copy(self: RASPExprT) -> RASPExprT: - """Returns a shallow copy of this RASPExpr with a new ID.""" - return copy.copy(self) - @property def label(self) -> str: return f"{self.name}_{self.unique_id}" @@ -156,7 +151,7 @@ def annotated(self: RASPExprT, **annotations) -> RASPExprT: def annotate(expr: RASPExprT, **annotations) -> RASPExprT: - """Creates a new expr with added annotations.""" + """Adds annotations to an expression""" # Note that new annotations will overwrite existing ones with matching keys. expr.annotations = {**expr.annotations, **annotations} return expr From badc071e0d3dd255887fb3a8cbcaa709f1cdbb95 Mon Sep 17 00:00:00 2001 From: William Baker Date: Thu, 1 Feb 2024 16:02:14 +0000 Subject: [PATCH 3/3] removed tests for copy() and copying when annotating --- tracr/rasp/rasp_test.py | 39 +-------------------------------------- 1 file changed, 1 insertion(+), 38 deletions(-) diff --git a/tracr/rasp/rasp_test.py b/tracr/rasp/rasp_test.py index 62d33a6..eaf5535 100644 --- a/tracr/rasp/rasp_test.py +++ b/tracr/rasp/rasp_test.py @@ -309,12 +309,7 @@ def test_is_categorical(self, sop: rasp.SOp): self.assertTrue(rasp.is_categorical(rasp.categorical(sop))) self.assertFalse(rasp.is_categorical(rasp.numerical(sop))) - @parameterized.named_parameters(*_SOP_EXAMPLES()) - def test_double_encoding_annotations_overwrites_encoding(self, sop: rasp.SOp): - num_sop = rasp.numerical(sop) - cat_num_sop = rasp.categorical(num_sop) - self.assertTrue(rasp.is_numerical(num_sop)) - self.assertTrue(rasp.is_categorical(cat_num_sop)) + class SelectorTest(parameterized.TestCase): @@ -453,38 +448,6 @@ def test_constant_selector(self): ) -class CopyTest(parameterized.TestCase): - - @parameterized.named_parameters(*_ALL_EXAMPLES()) - def test_copy_preserves_name(self, expr: rasp.RASPExpr): - expr = expr.named("foo") - self.assertEqual(expr.copy().name, expr.name) - - @parameterized.named_parameters(*_ALL_EXAMPLES()) - def test_renaming_copy_doesnt_rename_original(self, expr: rasp.RASPExpr): - expr = expr.named("foo") - expr.copy().named("bar") - self.assertEqual(expr.name, "foo") - - @parameterized.named_parameters(*_ALL_EXAMPLES()) - def test_renaming_original_doesnt_rename_copy(self, expr: rasp.RASPExpr): - expr = expr.named("foo") - copy = expr.copy() - expr.named("bar") - self.assertEqual(copy.name, "foo") - - @parameterized.named_parameters(*_ALL_EXAMPLES()) - def test_copy_changes_id(self, expr: rasp.RASPExpr): - self.assertNotEqual(expr.copy().unique_id, expr.unique_id) - - @parameterized.named_parameters(*_ALL_EXAMPLES()) - def test_copy_preserves_child_ids(self, expr: rasp.RASPExpr): - copy_child_ids = [c.unique_id for c in expr.copy().children] - child_ids = [c.unique_id for c in expr.children] - for child_id, copy_child_id in zip(child_ids, copy_child_ids): - self.assertEqual(child_id, copy_child_id) - - class AggregateTest(parameterized.TestCase): """Tests for Aggregate."""