diff --git a/src/evox/problems/numerical/maf.py b/src/evox/problems/numerical/maf.py index 98a893619..50b283970 100644 --- a/src/evox/problems/numerical/maf.py +++ b/src/evox/problems/numerical/maf.py @@ -12,7 +12,7 @@ @jit def inside(x, a, b): """check if x is in [a, b] or [b, a]""" - return (a <= x <= b) | (b <= x <= a) + return (jnp.minimum(a, b) <= x) & (x <= jnp.maximum(a, b)) @jit @@ -33,7 +33,7 @@ def ray_intersect_segment(point, seg_init, seg_term): judge_2 = (LHS >= RHS) ^ (y_dist < 0) # check intersection_y, which is P_y is inside the segment judge_3 = inside(point[1], seg_init[1], seg_term[1]) - return ((y_dist == 0) & judge_1) | (y_dist != 0 & judge_2 & judge_3) + return ((y_dist == 0) & judge_1) | ((y_dist != 0) & judge_2 & judge_3) @jit