diff --git a/tracr/rasp/rasp.py b/tracr/rasp/rasp.py index a074019..d943d8a 100644 --- a/tracr/rasp/rasp.py +++ b/tracr/rasp/rasp.py @@ -34,7 +34,6 @@ import abc import collections.abc -import copy import enum import functools import itertools @@ -138,10 +137,6 @@ def unique_id(self): """A unique id for every expression instance.""" return next(self._ids) - def copy(self: RASPExprT) -> RASPExprT: - """Returns a shallow copy of this RASPExpr with a new ID.""" - return copy.copy(self) - @property def label(self) -> str: return f"{self.name}_{self.unique_id}" @@ -156,11 +151,10 @@ def annotated(self: RASPExprT, **annotations) -> RASPExprT: def annotate(expr: RASPExprT, **annotations) -> RASPExprT: - """Creates a new expr with added annotations.""" - new = expr.copy() + """Adds annotations to an expression""" # Note that new annotations will overwrite existing ones with matching keys. - new.annotations = {**expr.annotations, **annotations} - return new + expr.annotations = {**expr.annotations, **annotations} + return expr ### S-Ops. diff --git a/tracr/rasp/rasp_test.py b/tracr/rasp/rasp_test.py index 62d33a6..eaf5535 100644 --- a/tracr/rasp/rasp_test.py +++ b/tracr/rasp/rasp_test.py @@ -309,12 +309,7 @@ def test_is_categorical(self, sop: rasp.SOp): self.assertTrue(rasp.is_categorical(rasp.categorical(sop))) self.assertFalse(rasp.is_categorical(rasp.numerical(sop))) - @parameterized.named_parameters(*_SOP_EXAMPLES()) - def test_double_encoding_annotations_overwrites_encoding(self, sop: rasp.SOp): - num_sop = rasp.numerical(sop) - cat_num_sop = rasp.categorical(num_sop) - self.assertTrue(rasp.is_numerical(num_sop)) - self.assertTrue(rasp.is_categorical(cat_num_sop)) + class SelectorTest(parameterized.TestCase): @@ -453,38 +448,6 @@ def test_constant_selector(self): ) -class CopyTest(parameterized.TestCase): - - @parameterized.named_parameters(*_ALL_EXAMPLES()) - def test_copy_preserves_name(self, expr: rasp.RASPExpr): - expr = expr.named("foo") - self.assertEqual(expr.copy().name, expr.name) - - @parameterized.named_parameters(*_ALL_EXAMPLES()) - def test_renaming_copy_doesnt_rename_original(self, expr: rasp.RASPExpr): - expr = expr.named("foo") - expr.copy().named("bar") - self.assertEqual(expr.name, "foo") - - @parameterized.named_parameters(*_ALL_EXAMPLES()) - def test_renaming_original_doesnt_rename_copy(self, expr: rasp.RASPExpr): - expr = expr.named("foo") - copy = expr.copy() - expr.named("bar") - self.assertEqual(copy.name, "foo") - - @parameterized.named_parameters(*_ALL_EXAMPLES()) - def test_copy_changes_id(self, expr: rasp.RASPExpr): - self.assertNotEqual(expr.copy().unique_id, expr.unique_id) - - @parameterized.named_parameters(*_ALL_EXAMPLES()) - def test_copy_preserves_child_ids(self, expr: rasp.RASPExpr): - copy_child_ids = [c.unique_id for c in expr.copy().children] - child_ids = [c.unique_id for c in expr.children] - for child_id, copy_child_id in zip(child_ids, copy_child_ids): - self.assertEqual(child_id, copy_child_id) - - class AggregateTest(parameterized.TestCase): """Tests for Aggregate."""