package spike;
import com.jnape.palatable.lambda.adt.choice.Choice2;
import com.jnape.palatable.lambda.adt.coproduct.CoProduct2;
import com.jnape.palatable.lambda.functions.Fn1;
import com.jnape.palatable.lambda.functions.Fn2;
import com.jnape.palatable.lambda.functions.builtin.fn2.Cons;
import com.jnape.palatable.lambda.functor.Bifunctor;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.function.Function;
import static com.jnape.palatable.lambda.adt.Maybe.just;
import static com.jnape.palatable.lambda.adt.Maybe.nothing;
import static com.jnape.palatable.lambda.adt.choice.Choice2.a;
import static com.jnape.palatable.lambda.adt.choice.Choice2.b;
import static com.jnape.palatable.lambda.adt.hlist.Tuple2.fill;
import static com.jnape.palatable.lambda.functions.builtin.fn1.Id.id;
import static com.jnape.palatable.lambda.functions.builtin.fn1.Last.last;
import static com.jnape.palatable.lambda.functions.builtin.fn2.Cons.cons;
import static com.jnape.palatable.lambda.functions.builtin.fn2.Partition.partition;
import static com.jnape.palatable.lambda.functions.builtin.fn2.Unfoldr.unfoldr;
import static spike.Spike.ContinuationStackFrame.recurse;
import static spike.Spike.ContinuationStackFrame.terminate;
import static spike.Spike.Corecursive.corecursive;
import static spike.Spike.OptimizedRecursiveCallStack.optimizedRecursiveCallStack;
import static spike.Spike.Recursive.recursive;
import static spike.Spike.Trampoline.trampoline;
public class Spike {
public static abstract class ContinuationStackFrame implements CoProduct2>, Bifunctor> {
@Override
public ContinuationStackFrame biMapL(Function super A, ? extends C> fn) {
throw new UnsupportedOperationException();
}
@Override
public ContinuationStackFrame biMapR(Function super B, ? extends C> fn) {
throw new UnsupportedOperationException();
}
@Override
public ContinuationStackFrame biMap(Function super A, ? extends C> lFn,
Function super B, ? extends D> rFn) {
throw new UnsupportedOperationException();
}
public static ContinuationStackFrame recurse(A a) {
return new Recurse<>(a);
}
public static ContinuationStackFrame terminate(B b) {
return new Terminate<>(b);
}
private static final class Recurse extends ContinuationStackFrame {
private final A a;
private Recurse(A a) {
this.a = a;
}
@Override
public R match(Function super A, ? extends R> aFn, Function super B, ? extends R> bFn) {
return aFn.apply(a);
}
@Override
public String toString() {
return "Recurse{" +
"a=" + a +
'}';
}
}
private static final class Terminate extends ContinuationStackFrame {
private final B b;
private Terminate(B b) {
this.b = b;
}
@Override
public R match(Function super A, ? extends R> aFn, Function super B, ? extends R> bFn) {
return bFn.apply(b);
}
@Override
public String toString() {
return "Terminate{" +
"b=" + b +
'}';
}
}
}
public static final class Trampoline implements Fn2>, A, B> {
private static final Trampoline INSTANCE = new Trampoline<>();
@Override
public B apply(Function super A, ? extends CoProduct2> fn, A a) {
CoProduct2 extends A, ? extends B, ?> next = fn.apply(a);
while (next.match(__ -> true, __ -> false))
next = fn.apply(next.match(id(), __ -> null));
return next.match(__ -> null, id());
}
@SuppressWarnings("unchecked")
public static Trampoline trampoline() {
return INSTANCE;
}
public static Fn1 trampoline(Function super A, ? extends CoProduct2> fn) {
return Trampoline.trampoline().apply(fn);
}
public static B trampoline(Function super A, ? extends CoProduct2> fn, A a) {
return trampoline(fn).apply(a);
}
}
public interface Recursive extends Fn1> {
default Fn1> unroll() {
return a -> optimizedRecursiveCallStack(cons(recurse(a), fill(apply(a))
.fmap(unfoldr(tc -> tc.match(next -> just(fill(apply(next))), __ -> nothing())))
.into(Cons::cons)));
}
static Recursive recursive(
Function super A, ? extends CoProduct2 extends A, ? extends B, ?>> f) {
return a -> f.apply(a).match(ContinuationStackFrame::recurse, ContinuationStackFrame::terminate);
}
default Fn1 trampoline() {
return Trampoline.trampoline(this);
}
}
public static final class Corecursive implements Recursive, C> {
private final Recursive> f;
private final Recursive> g;
private Corecursive(Recursive> f,
Recursive> g) {
this.f = f;
this.g = g;
}
@Override
public ContinuationStackFrame, C> apply(CoProduct2 ab) {
return ab.match(a -> f.apply(a).match(nextA -> recurse(Choice2.a(nextA)),
bc -> bc.match(b -> recurse(Choice2.b(b)), ContinuationStackFrame::terminate)),
b -> g.apply(b).match(nextB -> recurse(Choice2.b(nextB)),
ac -> ac.match(a -> recurse(Choice2.a(a)), ContinuationStackFrame::terminate)));
}
public static Corecursive corecursive(
Recursive> f,
Recursive> g
) {
return new Corecursive<>(f, g);
}
}
/**
 * Demo: defines two mutually dependent integer steps, combines them with {@code corecursive},
 * and prints the final result for the starting value 9 twice. The first line is computed by
 * unrolling into frames and taking the last terminating value, the second by trampolining.
 * <p>
 * The generic type arguments were missing from the reviewed listing and are reconstructed here.
 * NOTE(review): confirm the declared types of {@code evens} and {@code odds} against the
 * original source.
 *
 * @param args ignored
 */
public static void main(String... args) {
    // Even-side step: 1 finishes the whole computation; other even values are halved and stay
    // on this side; odd values are handed to the odd-side step unchanged.
    Recursive<Integer, CoProduct2<Integer, Integer, ?>> evens = recursive(i -> i == 1
            ? terminate(terminate(i))
            : i % 2 == 0
                    ? recurse(i / 2)
                    : terminate(recurse(i)));

    // Odd-side step, written with Choice2: 1 finishes the whole computation; other odd values
    // continue on this side as 3i + 1; even values are handed to the even-side step unchanged.
    Recursive<Integer, CoProduct2<Integer, Integer, ?>> odds =
            recursive(i -> i == 1 ? b(b(i)) : i % 2 == 1 ? a((i * 3) + 1) : b(a(i)));

    Corecursive<Integer, Integer, Integer> conjecture = corecursive(evens, odds);

    // A statement that built a trampolined function and discarded it without ever applying it
    // used to sit here; it had no observable effect and has been removed.

    Choice2<Integer, Integer> arg = Choice2.a(9);
    // unroll() already yields an OptimizedRecursiveCallStack, so roll() is called on it directly.
    System.out.println(conjecture.unroll().apply(arg).roll());
    System.out.println(trampoline(conjecture).apply(arg));
}
public static final class OptimizedRecursiveCallStack implements Iterable> {
private final Iterable> stack;
private OptimizedRecursiveCallStack(Iterable> stack) {
this.stack = stack;
}
public B roll() {
return last(partition(id(), stack)._2()).orElseThrow(NoSuchElementException::new);
}
@Override
public Iterator> iterator() {
return stack.iterator();
}
public static OptimizedRecursiveCallStack optimizedRecursiveCallStack(
Iterable> stack) {
return new OptimizedRecursiveCallStack<>(stack);
}
}
}