diff --git a/tracr/compiler/basis_inference.py b/tracr/compiler/basis_inference.py index a104312..5e5061f 100644 --- a/tracr/compiler/basis_inference.py +++ b/tracr/compiler/basis_inference.py @@ -57,6 +57,8 @@ def compute_value_set(sop: rasp.SOp) -> Set[rasp.Value]: res = errors.ignoring_arithmetic_errors(sop.f)(x) if res is not None: out.add(res) + if rasp.is_numerical(sop) and (not all(x >= 0 for x in out)): + raise ValueError(f"Map does not support negative outputs due to the ReLU activation\noutputs: {out}\nsop: {sop}") return out elif isinstance(sop, rasp.SequenceMap): f_ignore_error = errors.ignoring_arithmetic_errors(sop.f) diff --git a/tracr/compiler/validating.py b/tracr/compiler/validating.py index b785852..b0a4e0f 100644 --- a/tracr/compiler/validating.py +++ b/tracr/compiler/validating.py @@ -155,6 +155,18 @@ def evaluate( ) ) + elif isinstance(expr, rasp.Map): + if rasp.is_numerical(expr) and (not all(x >= 0 for x in out)): + self.unsupported_exprs.append( + TracrUnsupportedExpr( + expr=expr, + reason=( + "Map only supports positive outputs due to the ReLU activation" + f" got {set(out)}." + ), + ) + ) + return out