From e2f8d4554801a8adaaef92cdc2f056008d1fd664 Mon Sep 17 00:00:00 2001 From: William Baker Date: Thu, 11 Jan 2024 12:08:29 +0000 Subject: [PATCH 1/2] add checks to prevent negative outputs of map fst's that will be set to zero by the ReLU activation --- tracr/compiler/basis_inference.py | 2 ++ tracr/compiler/validating.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/tracr/compiler/basis_inference.py b/tracr/compiler/basis_inference.py index a104312..82a77e3 100644 --- a/tracr/compiler/basis_inference.py +++ b/tracr/compiler/basis_inference.py @@ -57,6 +57,8 @@ def compute_value_set(sop: rasp.SOp) -> Set[rasp.Value]: res = errors.ignoring_arithmetic_errors(sop.f)(x) if res is not None: out.add(res) + if not all(x >= 0 for x in out): + raise ValueError(f"Map does not support negative outputs due to the ReLU activation\noutputs: {out}\nsop: {sop}") return out elif isinstance(sop, rasp.SequenceMap): f_ignore_error = errors.ignoring_arithmetic_errors(sop.f) diff --git a/tracr/compiler/validating.py b/tracr/compiler/validating.py index b785852..400a412 100644 --- a/tracr/compiler/validating.py +++ b/tracr/compiler/validating.py @@ -155,6 +155,18 @@ def evaluate( ) ) + elif isinstance(expr, rasp.Map): + if not all(x >= 0 for x in out): + self.unsupported_exprs.append( + TracrUnsupportedExpr( + expr=expr, + reason=( + "Map only supports positive outputs due to the ReLU activation" + f" got {set(out)}." + ), + ) + ) + return out From 6cb1cc1792c9ab5e804e84e14458e3c23f76ea8e Mon Sep 17 00:00:00 2001 From: William Baker Date: Fri, 19 Jan 2024 12:27:48 +0000 Subject: [PATCH 2/2] add numerical check --- tracr/compiler/basis_inference.py | 2 +- tracr/compiler/validating.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tracr/compiler/basis_inference.py b/tracr/compiler/basis_inference.py index 82a77e3..5e5061f 100644 --- a/tracr/compiler/basis_inference.py +++ b/tracr/compiler/basis_inference.py @@ -57,7 +57,7 @@ def compute_value_set(sop: rasp.SOp) -> Set[rasp.Value]: res = errors.ignoring_arithmetic_errors(sop.f)(x) if res is not None: out.add(res) - if not all(x >= 0 for x in out): + if rasp.is_numerical(sop) and (not all(x >= 0 for x in out)): raise ValueError(f"Map does not support negative outputs due to the ReLU activation\noutputs: {out}\nsop: {sop}") return out elif isinstance(sop, rasp.SequenceMap): diff --git a/tracr/compiler/validating.py b/tracr/compiler/validating.py index 400a412..b0a4e0f 100644 --- a/tracr/compiler/validating.py +++ b/tracr/compiler/validating.py @@ -156,7 +156,7 @@ def evaluate( ) elif isinstance(expr, rasp.Map): - if not all(x >= 0 for x in out): + if rasp.is_numerical(expr) and (not all(x >= 0 for x in out)): self.unsupported_exprs.append( TracrUnsupportedExpr( expr=expr,