simple tail call optimization for Java
enables arbitrarily deep tail-recursive calls without throwing a StackOverflowError
no transitive dependencies
add the jitpack repository
<repositories>
...
<repository>
<id>jitpack.io</id>
<url>https://jitpack.io</url>
</repository>
...
</repositories>
add the dependency
<dependencies>
...
<dependency>
<groupId>com.github.nrktkt</groupId>
<artifactId>tail</artifactId>
<version>Tag</version>
</dependency>
...
</dependencies>
import com.github.kag0.tail.Tail;
import static com.github.kag0.tail.Tail.*;
Tail<Void> infiniteLoop(int i) {
System.out.println("Loop " + i + ", stack still intact!");
return call(() -> infiniteLoop(i + 1));
}
infiniteLoop(0).evaluate();
let's start with a simple recursive method to compute the factorial of n.
this code will throw a StackOverflowError for large values of n.
long factorial(long n) {
if(n == 1) return 1;
else return n * factorial(n - 1);
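}
first, move the recursive call into the tail position. the tail position is just another way of saying
"the last thing you do before the return". in the method above, that last thing is the multiplication, not the recursive call.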
long factorial(long fact, long n) {
if(n == 1) return fact;
return factorial(fact * n, n - 1);
}
this may require a slight refactor, usually adding an extra parameter that accumulates progress (here, fact).
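as a quick sanity check of that refactor, here is a plain-Java sketch (no Tail involved) that compares the original factorial with the accumulator version for every n from 1 to 20, the largest n whose factorial still fits in a long. the class name FactorialCheck is made up for illustration.

```java
class FactorialCheck {
    // original version: the multiplication happens after the recursive call returns
    static long factorial(long n) {
        if (n == 1) return 1;
        else return n * factorial(n - 1);
    }

    // refactored version: the recursive call is the last thing the method does,
    // and fact carries the product accumulated so far
    static long factorial(long fact, long n) {
        if (n == 1) return fact;
        return factorial(fact * n, n - 1);
    }

    public static void main(String[] args) {
        // 20! is the largest factorial that fits in a long
        for (long n = 1; n <= 20; n++) {
            if (factorial(n) != factorial(1, n)) {
                throw new AssertionError("mismatch at n = " + n);
            }
        }
        System.out.println("20! = " + factorial(1, 20)); // 20! = 2432902008176640000
    }
}
```

both versions still recurse on the Java stack at this point; the refactor only puts the recursive call where Tail can take it over.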
next, change the method's return type to Tail. this enforces that the recursive call is in the tail position.
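Tail<Long> factorial(long fact, long n)
then wrap the base case's return value in done.
if(n == 0) return done(fact);
and wrap the recursive call in call, passing a lambda so the call is deferred instead of made immediately.
return call(() -> factorial(fact * n, n - 1));
finally, invoke .evaluate() on the result of calling your method.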
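factorial(1, Long.MAX_VALUE).evaluate();
recursive methods no longer blow the stack. (this particular call overflows a long almost immediately, and counting down from Long.MAX_VALUE would take far too long to finish; it only illustrates that recursion depth is no longer limited by the stack.)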
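note that if you skip the 'move the recursive call into the tail position'
step, the code will not compile: in an expression like n * factorial(...), factorial now returns a Tail<Long>, which cannot be multiplied by a long.
a method that is not tail recursive is not stack safe, so with Tail that mistake becomes a type error instead of a StackOverflowError at runtime.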
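to see why evaluate() is not limited by the stack, here is a minimal sketch of the trampolining idea in plain Java. it is not Tail's implementation; the class Trampoline and its methods done, more and run are hypothetical names for this example. more plays roughly the role of call: it stores the next step in a lambda instead of invoking it, and run steps through those lambdas in a while loop, so the Java stack does not grow no matter how many steps there are.

```java
import java.util.function.Supplier;

// hypothetical minimal trampoline, for illustration only (not Tail's actual code)
abstract class Trampoline<T> {
    // base case: a finished result
    static <T> Trampoline<T> done(T value) {
        return new Done<>(value);
    }

    // recursive case: a deferred step that produces the next Trampoline when asked
    static <T> Trampoline<T> more(Supplier<Trampoline<T>> next) {
        return new More<>(next);
    }

    // runs the deferred steps one after another in a loop,
    // so the Java stack depth stays constant
    T run() {
        Trampoline<T> current = this;
        while (current instanceof More) {
            current = ((More<T>) current).next.get();
        }
        return ((Done<T>) current).value;
    }

    private static final class Done<T> extends Trampoline<T> {
        final T value;
        Done(T value) { this.value = value; }
    }

    private static final class More<T> extends Trampoline<T> {
        final Supplier<Trampoline<T>> next;
        More(Supplier<Trampoline<T>> next) { this.next = next; }
    }
}

class TrampolineDemo {
    // the tail-recursive factorial, written against the hypothetical Trampoline
    static Trampoline<Long> factorial(long fact, long n) {
        if (n == 0) return Trampoline.done(fact);
        return Trampoline.more(() -> factorial(fact * n, n - 1));
    }

    // a deep countdown: each step returns immediately instead of nesting a Java call
    static Trampoline<Long> countDown(long n) {
        if (n == 0) return Trampoline.done(0L);
        return Trampoline.more(() -> countDown(n - 1));
    }

    public static void main(String[] args) {
        System.out.println(factorial(1, 20).run());      // 2432902008176640000
        System.out.println(countDown(10_000_000).run()); // 0
    }
}
```

the trade-off in this design is that every step allocates a small object on the heap, in exchange for a stack depth that no longer depends on how deep the recursion goes.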
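in addition to making tail recursion safe, Tail's trampolining can also make recursive methods stack safe when they would otherwise be tricky to make tail recursive.
to do this, use .flatMap to chain two calls together: the second call receives the result of the first.
for example, the Ackermann function passes the result of one recursive call into another: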
Tail<Integer> ackermann(int m, int n) {
if(m == 0)
return done(n + 1);
if(m > 0 && n == 0)
return call(() -> ackermann(m - 1, 1));
if(m > 0 && n > 0)
return call(() -> ackermann(m, n - 1)).flatMap(nn -> ackermann(m - 1, nn));
throw new IllegalArgumentException();
}
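flatMap is the harder part to make stack safe, because a chain of flatMaps could itself end up nesting Java calls. the sketch below shows one common way to handle it, using a hypothetical Trampoline class that is not Tail's implementation: flatMap only records a FlatMap node, and run keeps the pending continuations on an explicit ArrayDeque on the heap rather than on the call stack. the ackermann method mirrors the one above, with the hypothetical more standing in for call.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.function.Function;
import java.util.function.Supplier;

// hypothetical trampoline with flatMap, for illustration only (not Tail's actual code)
abstract class Trampoline<T> {
    static <T> Trampoline<T> done(T value) { return new Done<>(value); }

    static <T> Trampoline<T> more(Supplier<Trampoline<T>> next) { return new More<>(next); }

    // chains a second computation that receives this computation's result
    <R> Trampoline<R> flatMap(Function<T, Trampoline<R>> f) {
        return new FlatMap<>(this, f);
    }

    // evaluates iteratively; continuations waiting for a result live on a heap deque
    @SuppressWarnings("unchecked")
    T run() {
        Deque<Function<Object, Trampoline<Object>>> continuations = new ArrayDeque<>();
        Trampoline<Object> current = (Trampoline<Object>) (Trampoline<?>) this;
        while (true) {
            if (current instanceof FlatMap) {
                // remember what to do with the result, then evaluate the source first
                FlatMap<Object, Object> fm = (FlatMap<Object, Object>) current;
                continuations.push(fm.f);
                current = fm.source;
            } else if (current instanceof More) {
                current = ((More<Object>) current).next.get();
            } else {
                Object value = ((Done<Object>) current).value;
                if (continuations.isEmpty()) return (T) value;
                // feed the result to the most recently recorded continuation
                current = continuations.pop().apply(value);
            }
        }
    }

    private static final class Done<T> extends Trampoline<T> {
        final T value;
        Done(T value) { this.value = value; }
    }

    private static final class More<T> extends Trampoline<T> {
        final Supplier<Trampoline<T>> next;
        More(Supplier<Trampoline<T>> next) { this.next = next; }
    }

    private static final class FlatMap<A, B> extends Trampoline<B> {
        final Trampoline<A> source;
        final Function<A, Trampoline<B>> f;
        FlatMap(Trampoline<A> source, Function<A, Trampoline<B>> f) {
            this.source = source;
            this.f = f;
        }
    }
}

class AckermannDemo {
    static Trampoline<Integer> ackermann(int m, int n) {
        if (m == 0) return Trampoline.done(n + 1);
        if (m > 0 && n == 0) return Trampoline.more(() -> ackermann(m - 1, 1));
        if (m > 0 && n > 0)
            return Trampoline.more(() -> ackermann(m, n - 1))
                    .flatMap(nn -> ackermann(m - 1, nn));
        throw new IllegalArgumentException("m and n must be non-negative");
    }

    public static void main(String[] args) {
        System.out.println(ackermann(2, 3).run()); // 9
        System.out.println(ackermann(3, 3).run()); // 61
    }
}
```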