import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*
import kotlin.math.*

/**
 * Defines a deep recursive function that keeps its stack on the heap,
 * which allows very deep recursive computations that do not use the actual call stack.
 * To initiate a call to this deep recursive function use its [invoke] function.
 * As a rule of thumb, it should be used if recursion goes deeper than a thousand calls.
 *
 * The [DeepRecursiveFunction] takes one parameter of type [T] and returns a result of type [R].
 * The [block] of code defines the body of a recursive function. In this block
 * [callRecursive][DeepRecursiveScope.callRecursive] function can be used to make a recursive call
 * to the declared function. Other instances of [DeepRecursiveFunction] can be called
 * in this scope with `callRecursive` extension, too.
 *
 * For example, take a look at the following recursive tree class and a deeply
 * recursive instance of this tree with a million nodes:
 *
 * ```
 * class Tree(val left: Tree? = null, val right: Tree? = null)
 * val deepTree = generateSequence(Tree()) { Tree(it) }.take(1_000_000).last()
 * ```
 *
 * A regular recursive function can be defined to compute the depth of a tree:
 *
 * ```
 * fun depth(t: Tree?): Int =
 *     if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1
 * println(depth(deepTree)) // StackOverflowError
 * ```
 *
 * If this `depth` function is called for a `deepTree` it produces [StackOverflowError] because of deep recursion.
 * However, the `depth` function can be rewritten using `DeepRecursiveFunction` in the following way, and then
 * it successfully computes [`depth(deepTree)`][DeepRecursiveFunction.invoke] expression:
 *
 * ```
 * val depth = DeepRecursiveFunction<Tree?, Int> { t ->
 *     if (t == null) 0 else max(callRecursive(t.left), callRecursive(t.right)) + 1
 * }
 * println(depth(deepTree)) // Ok
 * ```
 *
 * Deep recursive functions can also mutually call each other using a heap for the stack via
 * [callRecursive][DeepRecursiveScope.callRecursive] extension. For example, the
 * following pair of mutually recursive functions computes the number of tree nodes at even depth in the tree.
 *
 * ```
 * val mutualRecursion = object {
 *     val even: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
 *         if (t == null) 0 else odd.callRecursive(t.left) + odd.callRecursive(t.right) + 1
 *     }
 *     val odd: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
 *         if (t == null) 0 else even.callRecursive(t.left) + even.callRecursive(t.right)
 *     }
 * }
 * ```
 *
 * @param [T] the function parameter type.
 * @param [R] the function result type.
 * @param block the function body.
 */
public class DeepRecursiveFunction<T, R>(
    internal val block: suspend DeepRecursiveScope<T, R>.(T) -> R
)

/**
 * Initiates a call to this deep recursive function, forming a root of the call tree.
 *
 * This operator should not be used from inside of [DeepRecursiveScope] as it uses the call stack slot for
 * initial recursive invocation. From inside of [DeepRecursiveScope] use
 * [callRecursive][DeepRecursiveScope.callRecursive].
 */
public operator fun <T, R> DeepRecursiveFunction<T, R>.invoke(value: T): R =
    DeepRecursiveScopeImpl<T, R>(block, value).runCallLoop()

/**
 * A scope class for [DeepRecursiveFunction] function declaration that defines [callRecursive] methods to
 * recursively call this function or another [DeepRecursiveFunction] putting the call activation frame on the heap.
 *
 * [RestrictsSuspension] prevents the function body from calling arbitrary suspending functions:
 * only members and extensions of this scope (i.e. `callRecursive`) may suspend it.
 *
 * @param [T] function parameter type.
 * @param [R] function result type.
 */
@RestrictsSuspension
public sealed class DeepRecursiveScope<T, R> {
    /**
     * Makes a recursive call to this [DeepRecursiveFunction] function putting the call activation frame on the heap,
     * as opposed to the actual call stack that is used by a regular recursive call.
     */
    public abstract suspend fun callRecursive(value: T): R

    /**
     * Makes a call to the specified [DeepRecursiveFunction] function putting the call activation frame on the heap,
     * as opposed to the actual call stack that is used by a regular call.
     */
    public abstract suspend fun <U, S> DeepRecursiveFunction<U, S>.callRecursive(value: U): S

    @Deprecated(
        level = DeprecationLevel.ERROR,
        message =
        "'invoke' should not be called from DeepRecursiveScope. " +
            "Use 'callRecursive' to do recursion in the heap instead of the call stack.",
        replaceWith = ReplaceWith("this.callRecursive(value)")
    )
    public operator fun DeepRecursiveFunction<*, *>.invoke(value: Any?): Nothing =
        throw UnsupportedOperationException("Should not be called from DeepRecursiveScope")
}

// ================== Implementation ==================

/**
 * Erased shape of a `suspend DeepRecursiveScope<T, R>.(T) -> R` lambda as invoked directly:
 * (scope receiver, argument, completion continuation) -> result or [COROUTINE_SUSPENDED].
 */
private typealias DeepRecursiveFunctionBlock = Function3<Any?, Any?, Continuation<Any?>?, Any?>

/**
 * Sentinel stored in [DeepRecursiveScopeImpl.result] meaning "no pending result":
 * the call loop should start a new call of `function` with `value` instead of resuming `cont`.
 */
private val UNDEFINED_RESULT = Result.success(COROUTINE_SUSPENDED)

/**
 * Trampoline that executes a [DeepRecursiveFunction] call tree on the heap.
 *
 * The mutable fields act as "registers" read by [runCallLoop]:
 * - [function]: the block to invoke next,
 * - [value]: the argument for that invocation,
 * - [cont]: the continuation to complete with the invocation's result (`null` once the root call completes),
 * - [result]: [UNDEFINED_RESULT], or a result that must be delivered to [cont] by the loop.
 *
 * This scope itself is the root completion: [resumeWith] is called when the outermost call finishes.
 */
@Suppress("UNCHECKED_CAST")
private class DeepRecursiveScopeImpl<T, R>(
    block: suspend DeepRecursiveScope<T, R>.(T) -> R,
    value: T
) : DeepRecursiveScope<T, R>(), Continuation<R> {
    // Active function block
    private var function: DeepRecursiveFunctionBlock = block as DeepRecursiveFunctionBlock

    // Value to call function with
    private var value: Any? = value

    // Continuation of the current call
    private var cont: Continuation<Any?>? = this as Continuation<Any?>

    // Completion result (completion of the whole call stack)
    private var result: Result<Any?> = UNDEFINED_RESULT

    override val context: CoroutineContext
        get() = EmptyCoroutineContext

    /** Completion of the root call: clearing [cont] makes [runCallLoop] return/throw [result]. */
    override fun resumeWith(result: Result<R>) {
        this.cont = null
        this.result = result
    }

    /**
     * Suspends the caller and schedules a call of the same function with [value];
     * the actual invocation happens in [runCallLoop], so the JVM stack does not grow.
     */
    override suspend fun callRecursive(value: T): R = suspendCoroutineUninterceptedOrReturn { cont ->
        // calling the same function that is currently active
        this.cont = cont as Continuation<Any?>
        this.value = value
        COROUTINE_SUSPENDED
    }

    /**
     * Suspends the caller and schedules a call of another [DeepRecursiveFunction].
     * When the target differs from the active function, the caller's continuation is wrapped
     * so that the active function is restored once the callee completes.
     */
    override suspend fun <U, S> DeepRecursiveFunction<U, S>.callRecursive(value: U): S =
        suspendCoroutineUninterceptedOrReturn { cont ->
            // calling another recursive function
            val function = block as DeepRecursiveFunctionBlock
            with(this@DeepRecursiveScopeImpl) {
                val currentFunction = this.function
                if (function !== currentFunction) {
                    // calling a different function -- create a trampoline to restore function ref
                    this.function = function
                    this.cont = crossFunctionCompletion(currentFunction, cont as Continuation<Any?>)
                } else {
                    // calling the same function -- direct
                    this.cont = cont as Continuation<Any?>
                }
                this.value = value
            }
            COROUTINE_SUSPENDED
        }

    /**
     * Builds a completion for a cross-function call. When the callee finishes it restores
     * [currentFunction] and hands [cont] plus the callee's result back to [runCallLoop]
     * rather than resuming [cont] directly.
     */
    private fun crossFunctionCompletion(
        currentFunction: DeepRecursiveFunctionBlock,
        cont: Continuation<Any?>
    ): Continuation<Any?> = Continuation(EmptyCoroutineContext) {
        this.function = currentFunction
        // When going back from a trampoline we cannot just call cont.resume (stack usage!)
        // We delegate the cont.resumeWith(it) call to runCallLoop
        this.cont = cont
        this.result = it
    }

    /**
     * Drives the call tree until the root call completes, then returns its value
     * or rethrows its exception.
     */
    @Suppress("UNCHECKED_CAST")
    fun runCallLoop(): R {
        while (true) {
            // Note: cont is set to null in DeepRecursiveScopeImpl.resumeWith when the whole computation completes
            val result = this.result
            val cont = this.cont
                ?: return (result as Result<R>).getOrThrow() // done -- final result
            // The order of comparison is important here for that case of rogue class with broken equals:
            // UNDEFINED_RESULT wraps COROUTINE_SUSPENDED, whose equals is reliable, so it must be the receiver.
            if (UNDEFINED_RESULT == result) {
                // call "function" with "value" using "cont" as completion
                val r = try {
                    // This is block.startCoroutine(this, value, cont)
                    function(this, value, cont)
                } catch (e: Throwable) {
                    cont.resumeWithException(e)
                    continue
                }
                // If the function returns without suspension -- calls its continuation immediately
                if (r !== COROUTINE_SUSPENDED) cont.resume(r as R)
            } else {
                // we returned from a crossFunctionCompletion trampoline -- call resume here
                this.result = UNDEFINED_RESULT // reset result back
                cont.resumeWith(result)
            }
        }
    }
}

// ============== Test code ==============

class Tree(val left: Tree? = null, val right: Tree? = null)

fun main() {
    val n = 1_000_000
    val deepTree = generateSequence(Tree()) { Tree(it) }.take(n).last()

    fun binaryTree(k: Int): Tree? =
        if (k == 0) null else Tree(binaryTree(k - 1), binaryTree(k - 1))

    val k = 20
    val binaryTree = binaryTree(k)

    // Regular code: Stack overflow error
    // fun depth(t: Tree?): Int =
    //     if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1

    val depth = DeepRecursiveFunction<Tree?, Int> { t ->
        if (t == null) 0 else max(
            callRecursive(t.left),
            callRecursive(t.right)
        ) + 1
    }

    // A pair of mutually-recursive functions
    val mutualRecursion = object {
        val even: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
            if (t == null) 0 else odd.callRecursive(t.left) + odd.callRecursive(t.right) + 1
        }
        val odd: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
            if (t == null) 0 else even.callRecursive(t.left) + even.callRecursive(t.right)
        }
    }

    println("=== deepTree($n)")
    val dn = depth(deepTree)
    println("depth = $dn")
    println("even = ${mutualRecursion.even(deepTree)}")
    println(" odd = ${mutualRecursion.odd(deepTree)}")
    check(dn == n)

    println("=== binaryTree($k)")
    val dk = depth(binaryTree)
    println("depth = $dk")
    println("even = ${mutualRecursion.even(binaryTree)}")
    println(" odd = ${mutualRecursion.odd(binaryTree)}")
    check(dk == k)
    println("================")

    // Mix of call & callRecursive
    val mix = object {
        val b: DeepRecursiveFunction<Int, String> = DeepRecursiveFunction { i -> "b$i" }
        val a: DeepRecursiveFunction<Int, String> = DeepRecursiveFunction { i ->
            when (i) {
                0 -> b.callRecursive(1) + callRecursive(2) + a().callRecursive(3)
                else -> "a$i"
            }
        }

        fun a() = a
    }
    val s = mix.a.invoke(0)
    println("s = $s")
    check(s == "b1a2a3")

    // Mutually recursive tail calls & broken equals
    val tailRecs = object {
        var nullCount = 0
        val a: DeepRecursiveFunction<Tree?, Any> = DeepRecursiveFunction { t ->
            if (t == null) BadEquals(nullCount++) else b.callRecursive(t.left)
        }
        val b: DeepRecursiveFunction<Tree?, Any> = DeepRecursiveFunction { t ->
            if (t == null) BadEquals(nullCount++) else a.callRecursive(t.left)
        }
    }
    println("tailRecs = ${tailRecs.a.invoke(deepTree)}")
    check(tailRecs.nullCount == 1)
}

// Reports itself as equal to any other object; used to check that runCallLoop's sentinel comparison is robust.
private class BadEquals(val index: Int) {
    override fun equals(other: Any?): Boolean = true
    override fun toString(): String = "OK"
}