Skip to content
This repository was archived by the owner on Apr 10, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions tracr/rasp/rasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

import abc
import collections.abc
import copy
import enum
import functools
import itertools
Expand Down Expand Up @@ -138,10 +137,6 @@ def unique_id(self):
"""A unique id for every expression instance."""
return next(self._ids)

def copy(self: RASPExprT) -> RASPExprT:
"""Returns a shallow copy of this RASPExpr with a new ID."""
return copy.copy(self)

@property
def label(self) -> str:
return f"{self.name}_{self.unique_id}"
Expand All @@ -156,11 +151,10 @@ def annotated(self: RASPExprT, **annotations) -> RASPExprT:


def annotate(expr: RASPExprT, **annotations) -> RASPExprT:
"""Creates a new expr with added annotations."""
new = expr.copy()
"""Adds annotations to an expression"""
# Note that new annotations will overwrite existing ones with matching keys.
new.annotations = {**expr.annotations, **annotations}
return new
expr.annotations = {**expr.annotations, **annotations}
return expr


### S-Ops.
Expand Down
39 changes: 1 addition & 38 deletions tracr/rasp/rasp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,7 @@ def test_is_categorical(self, sop: rasp.SOp):
self.assertTrue(rasp.is_categorical(rasp.categorical(sop)))
self.assertFalse(rasp.is_categorical(rasp.numerical(sop)))

@parameterized.named_parameters(*_SOP_EXAMPLES())
def test_double_encoding_annotations_overwrites_encoding(self, sop: rasp.SOp):
num_sop = rasp.numerical(sop)
cat_num_sop = rasp.categorical(num_sop)
self.assertTrue(rasp.is_numerical(num_sop))
self.assertTrue(rasp.is_categorical(cat_num_sop))



class SelectorTest(parameterized.TestCase):
Expand Down Expand Up @@ -453,38 +448,6 @@ def test_constant_selector(self):
)


class CopyTest(parameterized.TestCase):

@parameterized.named_parameters(*_ALL_EXAMPLES())
def test_copy_preserves_name(self, expr: rasp.RASPExpr):
expr = expr.named("foo")
self.assertEqual(expr.copy().name, expr.name)

@parameterized.named_parameters(*_ALL_EXAMPLES())
def test_renaming_copy_doesnt_rename_original(self, expr: rasp.RASPExpr):
expr = expr.named("foo")
expr.copy().named("bar")
self.assertEqual(expr.name, "foo")

@parameterized.named_parameters(*_ALL_EXAMPLES())
def test_renaming_original_doesnt_rename_copy(self, expr: rasp.RASPExpr):
expr = expr.named("foo")
copy = expr.copy()
expr.named("bar")
self.assertEqual(copy.name, "foo")

@parameterized.named_parameters(*_ALL_EXAMPLES())
def test_copy_changes_id(self, expr: rasp.RASPExpr):
self.assertNotEqual(expr.copy().unique_id, expr.unique_id)

@parameterized.named_parameters(*_ALL_EXAMPLES())
def test_copy_preserves_child_ids(self, expr: rasp.RASPExpr):
copy_child_ids = [c.unique_id for c in expr.copy().children]
child_ids = [c.unique_id for c in expr.children]
for child_id, copy_child_id in zip(child_ids, copy_child_ids):
self.assertEqual(child_id, copy_child_id)


class AggregateTest(parameterized.TestCase):
"""Tests for Aggregate."""

Expand Down