Skip to content

Commit

Permalink
Eliminate some safeRecurse inner classes with lambdas
Browse files Browse the repository at this point in the history
  • Loading branch information
headius committed Jan 3, 2024
1 parent 18432dc commit 6f670cb
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 119 deletions.
28 changes: 9 additions & 19 deletions core/src/main/java/org/jruby/RubyComparable.java
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,18 @@ public static IRubyObject cmperr(IRubyObject recv, IRubyObject other) {
*
*/
public static IRubyObject invcmp(final ThreadContext context, final IRubyObject recv, final IRubyObject other) {
return invcmp(context, DEFAULT_INVCMP, recv, other);
return invcmp(context, RubyComparable::invcmpRecursive, recv, other);
}

private static final ThreadContext.RecursiveFunctionEx DEFAULT_INVCMP = new ThreadContext.RecursiveFunctionEx<IRubyObject>() {
@Override
public IRubyObject call(ThreadContext context, IRubyObject recv, IRubyObject other, boolean recur) {
if (recur || !sites(context).respond_to_op_cmp.respondsTo(context, other, other)) return context.nil;
return sites(context).op_cmp.call(context, other, other, recv);
}
};
private static IRubyObject invcmpRecursive(ThreadContext context, IRubyObject recv, IRubyObject other, boolean recur) {
if (recur || !sites(context).respond_to_op_cmp.respondsTo(context, other, other)) return context.nil;
return sites(context).op_cmp.call(context, other, other, recv);
}

/** rb_invcmp
*
*/
public static IRubyObject invcmp(final ThreadContext context, ThreadContext.RecursiveFunctionEx func, IRubyObject recv, IRubyObject other) {
public static IRubyObject invcmp(final ThreadContext context, ThreadContext.RecursiveFunctionEx<IRubyObject> func, IRubyObject recv, IRubyObject other) {
IRubyObject result = context.safeRecurse(func, recv, other, "<=>", true);

if (result.isNil()) return result;
Expand Down Expand Up @@ -167,7 +164,9 @@ private static IRubyObject callCmpMethod(final ThreadContext context, final IRub

if (recv == other) return context.tru;

IRubyObject result = context.safeRecurse(CMP_RECURSIVE, other, recv, "<=>", true);
IRubyObject result = context.safeRecurse(
(ctx, obj, self, recur) -> recur ? ctx.nil : sites(ctx).op_cmp.call(ctx, self, self, obj),
other, recv, "<=>", true);

// This is only to prevent throwing exceptions by cmperr - it has poor performance
if ( result.isNil() ) return returnValueOnError;
Expand Down Expand Up @@ -288,13 +287,4 @@ private static ComparableSites sites(ThreadContext context) {
return context.sites.Comparable;
}

private static class CmpRecursive implements ThreadContext.RecursiveFunctionEx<IRubyObject> {
@Override
public IRubyObject call(ThreadContext context, IRubyObject other, IRubyObject self, boolean recur) {
if (recur) return context.nil;
return sites(context).op_cmp.call(context, self, self, other);
}
}

private static final CmpRecursive CMP_RECURSIVE = new CmpRecursive();
}
103 changes: 52 additions & 51 deletions core/src/main/java/org/jruby/RubyNumeric.java
Original file line number Diff line number Diff line change
Expand Up @@ -605,24 +605,27 @@ protected final IRubyObject coerceBit(ThreadContext context, String method, IRub
}

protected final IRubyObject coerceBit(ThreadContext context, JavaSites.CheckedSites site, IRubyObject other) {
RubyArray ary = doCoerce(context, other, true);
final IRubyObject x = ary.eltOk(0);
IRubyObject y = ary.eltOk(1);
IRubyObject ret = context.safeRecurse(new ThreadContext.RecursiveFunctionEx<JavaSites.CheckedSites>() {
@Override
public IRubyObject call(ThreadContext context, JavaSites.CheckedSites site, IRubyObject obj, boolean recur) {
if (recur) {
throw context.runtime.newNameError(str(context.runtime, "recursive call to ", ids(context.runtime, site.methodName)), context.runtime.newSymbol(site.methodName));
}
return getMetaClass(x).finvokeChecked(context, x, site, obj);
}
}, site, y, site.methodName, true);
IRubyObject ret = context.safeRecurse(RubyNumeric::coerceBitRecursive, site, doCoerce(context, other, true), site.methodName, true);
if (ret == null) {
coerceFailed(context, other);
}
return ret;
}

private static IRubyObject coerceBitRecursive(ThreadContext context, JavaSites.CheckedSites site, IRubyObject _array, boolean recur) {
if (recur) {
Ruby runtime = context.runtime;
String methodName = site.methodName;
throw runtime.newNameError(str(runtime, "recursive call to ", ids(runtime, methodName)), runtime.newSymbol(methodName));
}

RubyArray array = (RubyArray) _array;
IRubyObject x = array.eltOk(0);
IRubyObject y = array.eltOk(1);

return getMetaClass(x).finvokeChecked(context, x, site, y);
}

/** rb_num_coerce_cmp
* coercion used for comparisons
*/
Expand Down Expand Up @@ -1649,13 +1652,49 @@ public IRubyObject dup() {
}

public static IRubyObject numFuncall(ThreadContext context, IRubyObject x, CallSite site) {
return context.safeRecurse(new NumFuncall0(), site, x, site.methodName, true);
return context.safeRecurse(RubyNumeric::numFuncall0, site, x, site.methodName, true);
}

public static IRubyObject numFuncall(ThreadContext context, final IRubyObject x, CallSite site, final IRubyObject value) {
return context.safeRecurse(new NumFuncall1(value), site, x, site.methodName, true);
}

private static IRubyObject numFuncall0(ThreadContext context, CallSite site, IRubyObject obj, boolean recur) {
if (recur) {
String name = site.methodName;
if (name.length() > 0 && Character.isLetterOrDigit(name.charAt(0))) {
throw context.runtime.newNameError(name, obj, name);
} else if (name.length() == 2 && name.charAt(1) == '@') {
throw context.runtime.newNameError(name, obj, name.substring(0, 1));
} else {
throw context.runtime.newNameError(name, obj, name);
}
}
return site.call(context, obj, obj);
}

private static class NumFuncall1 implements ThreadContext.RecursiveFunctionEx<CallSite> {
private final IRubyObject value;

public NumFuncall1(IRubyObject value) {
this.value = value;
}

@Override
public IRubyObject call(ThreadContext context, CallSite site, IRubyObject obj, boolean recur) {
if (recur) {
String name = site.methodName;
Ruby runtime = context.runtime;
if (name.length() > 0 && Character.isLetterOrDigit(name.charAt(0))) {
throw runtime.newNameError(name, obj, name);
} else {
throw runtime.newNameError(name, obj, name);
}
}
return site.call(context, obj, obj, value);
}
}

// MRI: macro FIXABLE, RB_FIXABLE
// Note: this does additional checks for inf and nan
public static boolean fixable(Ruby runtime, double f) {
Expand Down Expand Up @@ -1690,44 +1729,6 @@ public static boolean negFixable(double l) {
return l >= RubyFixnum.MIN;
}

private static class NumFuncall1 implements ThreadContext.RecursiveFunctionEx<CallSite> {
private final IRubyObject value;

public NumFuncall1(IRubyObject value) {
this.value = value;
}

@Override
public IRubyObject call(ThreadContext context, CallSite site, IRubyObject obj, boolean recur) {
if (recur) {
String name = site.methodName;
if (name.length() > 0 && Character.isLetterOrDigit(name.charAt(0))) {
throw context.runtime.newNameError(name, obj, name);
} else {
throw context.runtime.newNameError(name, obj, name);
}
}
return site.call(context, obj, obj, value);
}
}

private static class NumFuncall0 implements ThreadContext.RecursiveFunctionEx<CallSite> {
@Override
public IRubyObject call(ThreadContext context, CallSite site, IRubyObject obj, boolean recur) {
if (recur) {
String name = site.methodName;
if (name.length() > 0 && Character.isLetterOrDigit(name.charAt(0))) {
throw context.runtime.newNameError(name, obj, name);
} else if (name.length() == 2 && name.charAt(1) == '@') {
throw context.runtime.newNameError(name, obj, name.substring(0,1));
} else {
throw context.runtime.newNameError(name, obj, name);
}
}
return site.call(context, obj, obj);
}
}

@Deprecated
public IRubyObject floor() {
return floor(getRuntime().getCurrentContext());
Expand Down
16 changes: 6 additions & 10 deletions core/src/main/java/org/jruby/RubyRange.java
Original file line number Diff line number Diff line change
Expand Up @@ -336,20 +336,16 @@ public RubyFixnum hash(ThreadContext context) {
}

private static RubyString inspectValue(final ThreadContext context, IRubyObject value) {
return (RubyString) context.safeRecurse(INSPECT_RECURSIVE, value, value, "inspect", true);
return (RubyString) context.safeRecurse(RubyRange::inspectValueRecursive, value, value, "inspect", true);
}

private static class InspectRecursive implements ThreadContext.RecursiveFunctionEx<IRubyObject> {
@Override
public IRubyObject call(ThreadContext context, IRubyObject state, IRubyObject obj, boolean recur) {
if (recur) {
return RubyString.newString(context.runtime, ((RubyRange) obj).isExclusive ? "(... ... ...)" : "(... .. ...)");
} else {
return inspect(context, obj);
}
private static IRubyObject inspectValueRecursive(ThreadContext context, IRubyObject state, IRubyObject obj, boolean recur) {
if (recur) {
return RubyString.newString(context.runtime, ((RubyRange) obj).isExclusive ? "(... ... ...)" : "(... .. ...)");
} else {
return inspect(context, obj);
}
}
private static final InspectRecursive INSPECT_RECURSIVE = new InspectRecursive();

private static final byte[] DOTDOTDOT = new byte[]{'.', '.', '.'};

Expand Down
26 changes: 4 additions & 22 deletions core/src/main/java/org/jruby/RubyStruct.java
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ public RubyFixnum hash(ThreadContext context) {
IRubyObject[] values = this.values;
for (int i = 0; i < values.length; i++) {
h = (h << 1) | (h < 0 ? 1 : 0);
IRubyObject hash = context.safeRecurse(HashRecursive.INSTANCE, runtime, values[i], "hash", true);
IRubyObject hash = context.safeRecurse(
(ctx, runtime1, obj, recur) -> recur ? RubyFixnum.zero(runtime1) : invokedynamic(ctx, obj, HASH),
runtime, values[i], "hash", true);
h ^= RubyNumeric.num2long(hash);
}

Expand Down Expand Up @@ -696,7 +698,7 @@ else if (first != '#') {
@JRubyMethod(name = {"inspect", "to_s"})
public RubyString inspect(final ThreadContext context) {
// recursion guard
return (RubyString) context.safeRecurse(InspectRecursive.INSTANCE, this, this, "inspect", false);
return (RubyString) context.safeRecurse((ctx, self, obj, recur) -> self.inspectStruct(ctx, recur), this, this, "inspect", false);
}

@JRubyMethod(name = {"to_a", "deconstruct", "values"})
Expand Down Expand Up @@ -1032,17 +1034,6 @@ public IRubyObject call(ThreadContext context, IRubyObject other, IRubyObject se
}
}

private static class HashRecursive implements ThreadContext.RecursiveFunctionEx<Ruby> {

static final HashRecursive INSTANCE = new HashRecursive();

@Override
public IRubyObject call(ThreadContext context, Ruby runtime, IRubyObject obj, boolean recur) {
if (recur) return RubyFixnum.zero(runtime);
return invokedynamic(context, obj, HASH);
}
}

private static class EqualRecursive implements ThreadContext.RecursiveFunctionEx<IRubyObject> {

private static final EqualRecursive INSTANCE = new EqualRecursive();
Expand All @@ -1060,15 +1051,6 @@ public IRubyObject call(ThreadContext context, IRubyObject other, IRubyObject se
}
}

private static class InspectRecursive implements ThreadContext.RecursiveFunctionEx<RubyStruct> {

private static final ThreadContext.RecursiveFunctionEx INSTANCE = new InspectRecursive();

public IRubyObject call(ThreadContext context, RubyStruct self, IRubyObject obj, boolean recur) {
return self.inspectStruct(context, recur);
}
}

private static StructSites sites(ThreadContext context) {
return context.sites.Struct;
}
Expand Down
26 changes: 9 additions & 17 deletions core/src/main/java/org/jruby/runtime/JavaSites.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,9 @@ public static class StringSites {
public final CallSite op_and = new FunctionalCachingCallSite("&");
public final CheckedSites to_hash_checked = new CheckedSites("to_hash");

public final ThreadContext.RecursiveFunctionEx recursive_cmp = new ThreadContext.RecursiveFunctionEx<IRubyObject>() {
@Override
public IRubyObject call(ThreadContext context, IRubyObject recv, IRubyObject other, boolean recur) {
if (recur || !respond_to_cmp.respondsTo(context, other, other)) return context.nil;
return cmp.call(context, other, other, recv);
}
public final ThreadContext.RecursiveFunctionEx<IRubyObject> recursive_cmp = (context, recv, other, recur) -> {
if (recur || !respond_to_cmp.respondsTo(context, other, other)) return context.nil;
return cmp.call(context, other, other, recv);
};
}

Expand Down Expand Up @@ -319,12 +316,9 @@ public static class TimeSites {
public final RespondToCallSite respond_to_cmp = new RespondToCallSite("<=>");
public final CachingCallSite cmp = new FunctionalCachingCallSite("<=>");

public final ThreadContext.RecursiveFunctionEx recursive_cmp = new ThreadContext.RecursiveFunctionEx<IRubyObject>() {
@Override
public IRubyObject call(ThreadContext context, IRubyObject recv, IRubyObject other, boolean recur) {
if (recur || !respond_to_cmp.respondsTo(context, other, other)) return context.nil;
return cmp.call(context, other, other, recv);
}
public final ThreadContext.RecursiveFunctionEx<IRubyObject> recursive_cmp = (context, recv, other, recur) -> {
if (recur || !respond_to_cmp.respondsTo(context, other, other)) return context.nil;
return cmp.call(context, other, other, recv);
};

public final RespondToCallSite respond_to_to_int = new RespondToCallSite("to_int");
Expand Down Expand Up @@ -400,11 +394,9 @@ public static class HelpersSites {
public final CallSite hash = new FunctionalCachingCallSite("hash");
public final CallSite op_equal = new FunctionalCachingCallSite("==");

public final ThreadContext.RecursiveFunctionEx<Ruby> recursive_hash = new ThreadContext.RecursiveFunctionEx<Ruby>() {
public IRubyObject call(ThreadContext context, Ruby runtime, IRubyObject obj, boolean recur) {
if (recur) return RubyFixnum.zero(runtime);
return hash.call(context, obj, obj);
}
public final ThreadContext.RecursiveFunctionEx<Ruby> recursive_hash = (context, runtime, obj, recur) -> {
if (recur) return RubyFixnum.zero(runtime);
return hash.call(context, obj, obj);
};
}

Expand Down

0 comments on commit 6f670cb

Please sign in to comment.