Skip to content

Instantly share code, notes, and snippets.

@elizarov
Last active March 25, 2024 00:40
Show Gist options
  • Select an option

  • Save elizarov/861dda8c3e8c5ee36eaa6db4ad996568 to your computer and use it in GitHub Desktop.

Select an option

Save elizarov/861dda8c3e8c5ee36eaa6db4ad996568 to your computer and use it in GitHub Desktop.

Revisions

  1. elizarov revised this gist Apr 25, 2020. 1 changed file with 2 additions and 2 deletions.
    4 changes: 2 additions & 2 deletions DeepRecursiveFunction.kt
    Original file line number Diff line number Diff line change
    @@ -14,11 +14,11 @@ import kotlin.coroutines.intrinsics.*
    * in this scope with `callRecursive` extension, too.
    *
    * For example, take a look at the following recursive tree class and a deeply
    * recursive instance of this tree with a million nodes:
    * recursive instance of this tree with 100K nodes:
    *
    * ```
    * class Tree(val left: Tree? = null, val right: Tree? = null)
    * val deepTree = generateSequence(Tree()) { Tree(it) }.take(1_000_000).last()
    * val deepTree = generateSequence(Tree()) { Tree(it) }.take(100_000).last()
    * ```
    *
    * A regular recursive function can be defined to compute a depth of a tree:
  2. elizarov revised this gist Apr 25, 2020. 1 changed file with 0 additions and 3 deletions.
    3 changes: 0 additions & 3 deletions DeepRecursiveFunction.kt
    Original file line number Diff line number Diff line change
    @@ -1,6 +1,3 @@
    package rec


    import kotlin.coroutines.*
    import kotlin.coroutines.intrinsics.*

  3. elizarov renamed this gist Apr 25, 2020. 1 changed file with 12 additions and 10 deletions.
    22 changes: 12 additions & 10 deletions RecursiveFunction.kt → DeepRecursiveFunction.kt
    Original file line number Diff line number Diff line change
    @@ -1,6 +1,8 @@
    package rec


    import kotlin.coroutines.*
    import kotlin.coroutines.intrinsics.*
    import kotlin.math.*

    /**
    * Defines deep recursive function that keeps its stack on the heap,
    @@ -98,8 +100,8 @@ public sealed class DeepRecursiveScope<T, R> {
    @Deprecated(
    level = DeprecationLevel.ERROR,
    message =
    "'invoke' should not be called from DeepRecursiveScope. " +
    "Use 'callRecursive' to do recursion in the heap instead of the call stack.",
    "'invoke' should not be called from DeepRecursiveScope. " +
    "Use 'callRecursive' to do recursion in the heap instead of the call stack.",
    replaceWith = ReplaceWith("this.callRecursive(value)")
    )
    public operator fun DeepRecursiveFunction<*, *>.invoke(value: Any?): Nothing =
    @@ -216,11 +218,11 @@ fun main() {
    val binaryTree = binaryTree(k)

    // Regular code: Stack overflow error
    // fun depth(t: Tree?): Int =
    // if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1
    // fun depth(t: Tree?): Int =
    // if (t == null) 0 else maxOf(depth(t.left), depth(t.right)) + 1

    val depth = DeepRecursiveFunction<Tree?, Int> { t ->
    if (t == null) 0 else max(
    if (t == null) 0 else maxOf(
    callRecursive(t.left),
    callRecursive(t.right)
    ) + 1
    @@ -258,11 +260,11 @@ fun main() {
    val mix = object {
    val b: DeepRecursiveFunction<Int, String> = DeepRecursiveFunction { i -> "b$i" }
    val a: DeepRecursiveFunction<Int, String> = DeepRecursiveFunction { i ->
    when (i) {
    0 -> b.callRecursive(1) + callRecursive(2) + a().callRecursive(3)
    else -> "a$i"
    }
    when (i) {
    0 -> b.callRecursive(1) + callRecursive(2) + a().callRecursive(3)
    else -> "a$i"
    }
    }
    fun a() = a
    }
    val s = mix.a.invoke(0)
  4. elizarov revised this gist Apr 24, 2020. 1 changed file with 41 additions and 39 deletions.
    80 changes: 41 additions & 39 deletions RecursiveFunction.kt
    Original file line number Diff line number Diff line change
    @@ -5,13 +5,14 @@ import kotlin.math.*
    /**
    * Defines deep recursive function that keeps its stack on the heap,
    * which allows very deep recursive computations that do not use the actual call stack.
    * To initiate a call to this deep recursive function use [initiateCall] function.
    * To initiate a call to this deep recursive function use its [invoke] function.
    * As a rule of thumb, it should be used if recursion goes deeper than a thousand calls.
    *
    * The [DeepRecursiveFunction] takes one parameter of type [T] and returns a result of type [R].
    * The [block] of code defines the body of a recursive function. In this block
    * [callRecursive][DeepRecursiveScope.callRecursive] function can be used to make a recursive call
    * to the declared function. Other instances of [DeepRecursiveFunction] can be invoked
    * in this scope, too.
    * to the declared function. Other instances of [DeepRecursiveFunction] can be called
    * in this scope with `callRecursive` extension, too.
    *
    * For example, take a look at the following recursive tree class and a deeply
    * recursive instance of this tree with a million nodes:
    @@ -30,28 +31,27 @@ import kotlin.math.*
    * ```
    *
    * If this `depth` function is called for a `deepTree` it produces [StackOverflowError] because of deep recursion.
    * However, the `depth` function can be rewritten using `DeepRecursiveFunction` in the following way, then
    * it successfully computes [`depth.initiateCall(deepTree)`][DeepRecursiveFunction.initiateCall] expression:
    * However, the `depth` function can be rewritten using `DeepRecursiveFunction` in the following way, and then
    * it successfully computes [`depth(deepTree)`][DeepRecursiveFunction.invoke] expression:
    *
    * ```
    * val depth = DeepRecursiveFunction<Tree?, Int> { t ->
    * if (t == null) 0 else max(callRecursive(t.left), callRecursive(t.right)) + 1
    * }
    * println(depth.initiateCall(deepTree)) // Ok
    * println(depth(deepTree)) // Ok
    * ```
    *
    * Deep recursive functions can also mutually call each other using a heap for the stack. The invocation of
    * [DeepRecursiveFunction] in [DeepRecursiveScope] automatically resolves properly via
    * [DeepRecursiveScope.invoke] operator. For example, the
    * Deep recursive functions can also mutually call each other using a heap for the stack via
    * [callRecursive][DeepRecursiveScope.callRecursive] extension. For example, the
    * following pair of mutually recursive functions computes the number of tree nodes at even depth in the tree.
    *
    * ```
    * val mutualRecursion = object {
    * val even: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
    * if (t == null) 0 else odd(t.left) + odd(t.right) + 1
    * if (t == null) 0 else odd.callRecursive(t.left) + odd.callRecursive(t.right) + 1
    * }
    * val odd: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
    * if (t == null) 0 else even(t.left) + even(t.right)
    * if (t == null) 0 else even.callRecursive(t.left) + even.callRecursive(t.right)
    * }
    * }
    * ```
    @@ -67,17 +67,16 @@ public class DeepRecursiveFunction<T, R>(
    /**
    * Initiates a call to this deep recursive function, forming a root of the call tree.
    *
    * This method should not be used from inside of [DeepRecursiveScope] as it uses the call stack slot for
    * initial recursive invocation. From inside of [DeepRecursiveScope] use either:
    * * [DeepRecursiveScope.callRecursive] to recursively call the current function or
    * * [DeepRecursiveScope.invoke] to call another function for a mutually recursive call.
    * This operator should not be used from inside of [DeepRecursiveScope] as it uses the call stack slot for
    * initial recursive invocation. From inside of [DeepRecursiveScope] use
    * [callRecursive][DeepRecursiveScope.callRecursive].
    */
    public fun <T, R> DeepRecursiveFunction<T, R>.initiateCall(value: T): R =
    public operator fun <T, R> DeepRecursiveFunction<T, R>.invoke(value: T): R =
    DeepRecursiveScopeImpl<T, R>(block, value).runCallLoop()

    /**
    * A scope class for [DeepRecursiveFunction] function declaration that defines [callRecursive] function and
    * provides ability to [invoke] any [DeepRecursiveFunction] putting the call activation frame on the heap.
    * A scope class for [DeepRecursiveFunction] function declaration that defines [callRecursive] methods to
    * recursively call this function or another [DeepRecursiveFunction] putting the call activation frame on the heap.
    *
    * @param [T] function parameter type.
    * @param [R] function result type.
    @@ -94,16 +93,16 @@ public sealed class DeepRecursiveScope<T, R> {
    * Makes call to the specified [DeepRecursiveFunction] function putting the call activation frame on the heap,
    * as opposed to the actual call stack that is used by a regular call.
    */
    public abstract suspend operator fun <U, S> DeepRecursiveFunction<U, S>.invoke(value: U): S
    public abstract suspend fun <U, S> DeepRecursiveFunction<U, S>.callRecursive(value: U): S

    @Deprecated(
    level = DeprecationLevel.ERROR,
    message =
    "'initiateCall' should not be called from DeepRecursiveScope. " +
    "Use 'invoke' to for the recursion in the heap instead of the call stack.",
    replaceWith = ReplaceWith("this(value)")
    "'invoke' should not be called from DeepRecursiveScope. " +
    "Use 'callRecursive' to do recursion in the heap instead of the call stack.",
    replaceWith = ReplaceWith("this.callRecursive(value)")
    )
    public fun DeepRecursiveFunction<*, *>.initiateCall(value: Any?): Nothing =
    public operator fun DeepRecursiveFunction<*, *>.invoke(value: Any?): Nothing =
    throw UnsupportedOperationException("Should not be called from DeepRecursiveScope")
    }

    @@ -145,7 +144,7 @@ private class DeepRecursiveScopeImpl<T, R>(
    COROUTINE_SUSPENDED
    }

    override suspend fun <U, S> DeepRecursiveFunction<U, S>.invoke(value: U): S = suspendCoroutineUninterceptedOrReturn { cont ->
    override suspend fun <U, S> DeepRecursiveFunction<U, S>.callRecursive(value: U): S = suspendCoroutineUninterceptedOrReturn { cont ->
    // calling another recursive function
    val function = block as DeepRecursiveFunctionBlock
    with(this@DeepRecursiveScopeImpl) {
    @@ -221,34 +220,37 @@ fun main() {
    // if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1

    val depth = DeepRecursiveFunction<Tree?, Int> { t ->
    if (t == null) 0 else max(callRecursive(t.left), callRecursive(t.right)) + 1
    if (t == null) 0 else max(
    callRecursive(t.left),
    callRecursive(t.right)
    ) + 1
    }

    // A pair of mutually-recursive functions
    val mutualRecursion = object {
    val even: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
    if (t == null) 0 else odd(t.left) + odd(t.right) + 1
    if (t == null) 0 else odd.callRecursive(t.left) + odd.callRecursive(t.right) + 1
    }

    val odd: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
    if (t == null) 0 else even(t.left) + even(t.right)
    if (t == null) 0 else even.callRecursive(t.left) + even.callRecursive(t.right)
    }
    }

    println("=== deepTree($n)")

    val dn = depth.initiateCall(deepTree)
    val dn = depth(deepTree)
    println("depth = $dn")
    println("even = ${mutualRecursion.even.initiateCall(deepTree)}")
    println(" odd = ${mutualRecursion.odd.initiateCall(deepTree)}")
    println("even = ${mutualRecursion.even(deepTree)}")
    println(" odd = ${mutualRecursion.odd(deepTree)}")
    check(dn == n)

    println("=== binaryTree($k)")

    val dk = depth.initiateCall(binaryTree)
    val dk = depth(binaryTree)
    println("depth = $dk")
    println("even = ${mutualRecursion.even.initiateCall(binaryTree)}")
    println(" odd = ${mutualRecursion.odd.initiateCall(binaryTree)}")
    println("even = ${mutualRecursion.even(binaryTree)}")
    println(" odd = ${mutualRecursion.odd(binaryTree)}")
    check(dk == k)

    println("================")
    @@ -257,27 +259,27 @@ fun main() {
    val b: DeepRecursiveFunction<Int, String> = DeepRecursiveFunction { i -> "b$i" }
    val a: DeepRecursiveFunction<Int, String> = DeepRecursiveFunction { i ->
    when (i) {
    0 -> b(1) + callRecursive(2) + aa()(3)
    0 -> b.callRecursive(1) + callRecursive(2) + a().callRecursive(3)
    else -> "a$i"
    }
    }
    fun aa() = a
    fun a() = a
    }
    val s = mix.a.initiateCall(0)
    val s = mix.a.invoke(0)
    println("s = $s")
    check(s == "b1a2a3")

    // Mutually recursive tail calls & broken equals
    val tailRecs = object {
    var nullCount = 0
    val a: DeepRecursiveFunction<Tree?, BadEquals> = DeepRecursiveFunction { t ->
    if (t == null) BadEquals(nullCount++) else b(t.left)
    if (t == null) BadEquals(nullCount++) else b.callRecursive(t.left)
    }
    val b: DeepRecursiveFunction<Tree?, BadEquals> = DeepRecursiveFunction { t ->
    if (t == null) BadEquals(nullCount++) else a(t.left)
    if (t == null) BadEquals(nullCount++) else a.callRecursive(t.left)
    }
    }
    println("tailRecs = ${tailRecs.a.initiateCall(deepTree)}")
    println("tailRecs = ${tailRecs.a.invoke(deepTree)}")
    check(tailRecs.nullCount == 1)
    }

  5. elizarov revised this gist Apr 24, 2020. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion RecursiveFunction.kt
    Original file line number Diff line number Diff line change
    @@ -60,7 +60,7 @@ import kotlin.math.*
    * @param [R] the function result type.
    * @param block the function body.
    */
    public inline class DeepRecursiveFunction<T, R>(
    public class DeepRecursiveFunction<T, R>(
    internal val block: suspend DeepRecursiveScope<T, R>.(T) -> R
    )

  6. elizarov revised this gist Apr 24, 2020. 1 changed file with 54 additions and 48 deletions.
    102 changes: 54 additions & 48 deletions RecursiveFunction.kt
    Original file line number Diff line number Diff line change
    @@ -5,8 +5,9 @@ import kotlin.math.*
    /**
    * Defines deep recursive function that keeps its stack on the heap,
    * which allows very deep recursive computations that do not use the actual call stack.
    * To initiate a call to this deep recursive function use [initiateCall] function.
    *
    * The function takes one parameter of type [T] and returns a result of type [R].
    * The [DeepRecursiveFunction] takes one parameter of type [T] and returns a result of type [R].
    * The [block] of code defines the body of a recursive function. In this block
    * [callRecursive][DeepRecursiveScope.callRecursive] function can be used to make a recursive call
    * to the declared function. Other instances of [DeepRecursiveFunction] can be invoked
    @@ -25,59 +26,57 @@ import kotlin.math.*
    * ```
    * fun depth(t: Tree?): Int =
    * if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1
    * println(depth(deepTree)) // StackOverflowError
    * ```
    *
    * If this `depth` function is called for a `deepTree` it produces [StackOverflowError] because
    * of deep recursion unless a very large stack size is used.
    *
    * However, the `depth` function can be rewritten using `deepRecursive` function in the following way so that
    * it successfully computes `depth(deepTree)`:
    * If this `depth` function is called for a `deepTree` it produces [StackOverflowError] because of deep recursion.
    * However, the `depth` function can be rewritten using `DeepRecursiveFunction` in the following way, then
    * it successfully computes [`depth.initiateCall(deepTree)`][DeepRecursiveFunction.initiateCall] expression:
    *
    * ```
    * val depth = deepRecursive<Tree?, Int> { t ->
    * val depth = DeepRecursiveFunction<Tree?, Int> { t ->
    * if (t == null) 0 else max(callRecursive(t.left), callRecursive(t.right)) + 1
    * }
    * println(depth.initiateCall(deepTree)) // Ok
    * ```
    *
    * Deep recursive functions can also mutually call each other using a heap for the stack. The invocation of
    * [DeepRecursiveFunction] in [DeepRecursiveScope] automatically resolves properly. For example, the
    * [DeepRecursiveFunction] in [DeepRecursiveScope] automatically resolves properly via
    * [DeepRecursiveScope.invoke] operator. For example, the
    * following pair of mutually recursive functions computes the number of tree nodes at even depth in the tree.
    *
    * ```
    * val mutualRecursion = object {
    * val even: DeepRecursiveFunction<Tree?, Int> = deepRecursive { t ->
    * val even: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
    * if (t == null) 0 else odd(t.left) + odd(t.right) + 1
    * }
    * val odd: DeepRecursiveFunction<Tree?, Int> = deepRecursive { t ->
    * val odd: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
    * if (t == null) 0 else even(t.left) + even(t.right)
    * }
    * }
    * ```
    *
    * @param [T] function parameter type.
    * @param [R] function result type.
    * @param block function body.
    */
    public fun <T, R> deepRecursive(block: suspend DeepRecursiveScope<T, R>.(T) -> R): DeepRecursiveFunction<T, R> =
    DeepRecursiveFunction(block)

    /**
    * An instance of deep recursive function that is created by [deepRecursive] builder.
    * To initiate a call to this deep recursive function use operator [invoke] extension.
    * @param [T] the function parameter type.
    * @param [R] the function result type.
    * @param block the function body.
    */
    public class DeepRecursiveFunction<T, R> internal constructor(
    public inline class DeepRecursiveFunction<T, R>(
    internal val block: suspend DeepRecursiveScope<T, R>.(T) -> R
    ) {
    internal fun callImpl(value: T): R = DeepRecursiveScopeImpl<T, R>(block, value).runCallLoop()
    }
    )

    /**
    * Initiates a call to this deep recursive function, forming a root of the call tree.
    *
    * This method should not be used from inside of [DeepRecursiveScope] as it uses the call stack slot for
    * initial recursive invocation. From inside of [DeepRecursiveScope] use either:
    * * [DeepRecursiveScope.callRecursive] to recursively call the current function or
    * * [DeepRecursiveScope.invoke] to call another function for a mutually recursive call.
    */
    public operator fun <T, R> DeepRecursiveFunction<T, R>.invoke(value: T): R = callImpl(value)
    public fun <T, R> DeepRecursiveFunction<T, R>.initiateCall(value: T): R =
    DeepRecursiveScopeImpl<T, R>(block, value).runCallLoop()

    /**
    * A scope class for [deepRecursive] function declaration that defines [callRecursive] function and
    * A scope class for [DeepRecursiveFunction] function declaration that defines [callRecursive] function and
    * provides ability to [invoke] any [DeepRecursiveFunction] putting the call activation frame on the heap.
    *
    * @param [T] function parameter type.
    @@ -86,7 +85,7 @@ public operator fun <T, R> DeepRecursiveFunction<T, R>.invoke(value: T): R = cal
    @RestrictsSuspension
    public sealed class DeepRecursiveScope<T, R> {
    /**
    * Makes recursive call to this [deepRecursive] function putting the call activation frame on the heap,
    * Makes recursive call to this [DeepRecursiveFunction] function putting the call activation frame on the heap,
    * as opposed to the actual call stack that is used by a regular recursive call.
    */
    public abstract suspend fun callRecursive(value: T): R
    @@ -96,6 +95,16 @@ public sealed class DeepRecursiveScope<T, R> {
    * as opposed to the actual call stack that is used by a regular call.
    */
    public abstract suspend operator fun <U, S> DeepRecursiveFunction<U, S>.invoke(value: U): S

    @Deprecated(
    level = DeprecationLevel.ERROR,
    message =
    "'initiateCall' should not be called from DeepRecursiveScope. " +
    "Use 'invoke' to for the recursion in the heap instead of the call stack.",
    replaceWith = ReplaceWith("this(value)")
    )
    public fun DeepRecursiveFunction<*, *>.initiateCall(value: Any?): Nothing =
    throw UnsupportedOperationException("Should not be called from DeepRecursiveScope")
    }

    // ================== Implementation ==================
    @@ -211,72 +220,69 @@ fun main() {
    // fun depth(t: Tree?): Int =
    // if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1

    val depth = deepRecursive<Tree?, Int> { t ->
    val depth = DeepRecursiveFunction<Tree?, Int> { t ->
    if (t == null) 0 else max(callRecursive(t.left), callRecursive(t.right)) + 1
    }

    // A pair of mutually-recursive functions
    val mutualRecursion = object {
    val even: DeepRecursiveFunction<Tree?, Int> = deepRecursive { t ->
    val even: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
    if (t == null) 0 else odd(t.left) + odd(t.right) + 1
    }

    val odd: DeepRecursiveFunction<Tree?, Int> = deepRecursive { t ->
    val odd: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
    if (t == null) 0 else even(t.left) + even(t.right)
    }
    }

    println("=== deepTree($n)")

    val dn = depth(deepTree)
    val dn = depth.initiateCall(deepTree)
    println("depth = $dn")
    println("even = ${mutualRecursion.even(deepTree)}")
    println(" odd = ${mutualRecursion.odd(deepTree)}")
    println("even = ${mutualRecursion.even.initiateCall(deepTree)}")
    println(" odd = ${mutualRecursion.odd.initiateCall(deepTree)}")
    check(dn == n)

    println("=== binaryTree($k)")

    val dk = depth(binaryTree)
    val dk = depth.initiateCall(binaryTree)
    println("depth = $dk")
    println("even = ${mutualRecursion.even(binaryTree)}")
    println(" odd = ${mutualRecursion.odd(binaryTree)}")
    println("even = ${mutualRecursion.even.initiateCall(binaryTree)}")
    println(" odd = ${mutualRecursion.odd.initiateCall(binaryTree)}")
    check(dk == k)

    println("================")
    // Mix of call & callRecursive
    val mix = object {
    val b: DeepRecursiveFunction<Int, String> = deepRecursive { i -> "b$i" }
    lateinit var a: DeepRecursiveFunction<Int, String>

    init {
    a = deepRecursive { i ->
    val b: DeepRecursiveFunction<Int, String> = DeepRecursiveFunction { i -> "b$i" }
    val a: DeepRecursiveFunction<Int, String> = DeepRecursiveFunction { i ->
    when (i) {
    0 -> b(1) + callRecursive(2) + a(3)
    0 -> b(1) + callRecursive(2) + aa()(3)
    else -> "a$i"
    }
    }
    }
    fun aa() = a
    }
    val s = mix.a(0)
    val s = mix.a.initiateCall(0)
    println("s = $s")
    check(s == "b1a2a3")

    // Mutually recursive tail calls & broken equals
    val tailRecs = object {
    var nullCount = 0
    val a: DeepRecursiveFunction<Tree?, BadEquals> = deepRecursive { t ->
    val a: DeepRecursiveFunction<Tree?, BadEquals> = DeepRecursiveFunction { t ->
    if (t == null) BadEquals(nullCount++) else b(t.left)
    }
    val b: DeepRecursiveFunction<Tree?, BadEquals> = deepRecursive { t ->
    val b: DeepRecursiveFunction<Tree?, BadEquals> = DeepRecursiveFunction { t ->
    if (t == null) BadEquals(nullCount++) else a(t.left)
    }
    }
    println("tailRecs = ${tailRecs.a(deepTree)}")
    println("tailRecs = ${tailRecs.a.initiateCall(deepTree)}")
    check(tailRecs.nullCount == 1)
    }

    // It is equals to any other class
    private class BadEquals(val index: Int) {
    override fun equals(other: Any?): Boolean = true
    override fun toString(): String = "OK"
    }
    }
  7. elizarov revised this gist Nov 23, 2019. 1 changed file with 65 additions and 20 deletions.
    85 changes: 65 additions & 20 deletions RecursiveFunction.kt
    Original file line number Diff line number Diff line change
    @@ -102,6 +102,8 @@ public sealed class DeepRecursiveScope<T, R> {

    private typealias DeepRecursiveFunctionBlock = Function3<Any?, Any?, Continuation<Any?>?, Any?>

    private val UNDEFINED_RESULT = Result.success(COROUTINE_SUSPENDED)

    @Suppress("UNCHECKED_CAST")
    private class DeepRecursiveScopeImpl<T, R>(
    block: suspend DeepRecursiveScope<T, R>.(T) -> R,
    @@ -117,7 +119,7 @@ private class DeepRecursiveScopeImpl<T, R>(
    private var cont: Continuation<Any?>? = this as Continuation<Any?>

    // Completion result (completion of the whole call stack)
    private var result: Result<Any?> = Result.success(null)
    private var result: Result<Any?> = UNDEFINED_RESULT

    override val context: CoroutineContext
    get() = EmptyCoroutineContext
    @@ -142,10 +144,7 @@ private class DeepRecursiveScopeImpl<T, R>(
    if (function !== currentFunction) {
    // calling a different function -- create a trampoline to restore function ref
    this.function = function
    this.cont = Continuation(EmptyCoroutineContext) {
    this.function = currentFunction
    (cont as Continuation<Any?>).resumeWith(it)
    }
    this.cont = crossFunctionCompletion(currentFunction, cont as Continuation<Any?>)
    } else {
    // calling the same function -- direct
    this.cont = cont as Continuation<Any?>
    @@ -155,20 +154,42 @@ private class DeepRecursiveScopeImpl<T, R>(
    COROUTINE_SUSPENDED
    }

    private fun crossFunctionCompletion(
    currentFunction: DeepRecursiveFunctionBlock,
    cont: Continuation<Any?>
    ): Continuation<Any?> = Continuation(EmptyCoroutineContext) {
    this.function = currentFunction
    // When going back from a trampoline we cannot just call cont.resume (stack usage!)
    // We delegate the cont.resumeWith(it) call to runCallLoop
    this.cont = cont
    this.result = it
    }

    @Suppress("UNCHECKED_CAST")
    fun runCallLoop(): R {
    while (true) {
    // Note: cont is set to null in DeepRecursiveScopeImpl.resumeWith when the whole computation completes
    val result = this.result
    val cont = this.cont
    ?: return (result as Result<R>).getOrThrow()
    // This is block.startCoroutine(this, value, cont)
    val r = try {
    function(this, value, cont)
    } catch (e: Throwable) {
    cont.resumeWithException(e)
    continue
    ?: return (result as Result<R>).getOrThrow() // done -- final result
    // The order of comparison is important here for that case of rogue class with broken equals
    if (UNDEFINED_RESULT == result) {
    // call "function" with "value" using "cont" as completion
    val r = try {
    // This is block.startCoroutine(this, value, cont)
    function(this, value, cont)
    } catch (e: Throwable) {
    cont.resumeWithException(e)
    continue
    }
    // If the function returns without suspension -- calls its continuation immediately
    if (r !== COROUTINE_SUSPENDED)
    cont.resume(r as R)
    } else {
    // we returned from a crossFunctionCompletion trampoline -- call resume here
    this.result = UNDEFINED_RESULT // reset result back
    cont.resumeWith(result)
    }
    if (r !== COROUTINE_SUSPENDED)
    cont.resume(r as R)
    }
    }
    }
    @@ -178,6 +199,14 @@ private class DeepRecursiveScopeImpl<T, R>(
    class Tree(val left: Tree? = null, val right: Tree? = null)

    fun main() {
    val n = 1_000_000
    val deepTree = generateSequence(Tree()) { Tree(it) }.take(n).last()

    fun binaryTree(k: Int): Tree? =
    if (k == 0) null else Tree(binaryTree(k - 1), binaryTree(k - 1))
    val k = 20
    val binaryTree = binaryTree(k)

    // Regular code: Stack overflow error
    // fun depth(t: Tree?): Int =
    // if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1
    @@ -197,26 +226,23 @@ fun main() {
    }
    }

    val n = 1_000_000
    println("=== deepTree($n)")
    val deepTree = generateSequence(Tree()) { Tree(it) }.take(n).last()

    val dn = depth(deepTree)
    println("depth = $dn")
    println("even = ${mutualRecursion.even(deepTree)}")
    println(" odd = ${mutualRecursion.odd(deepTree)}")
    check(dn == n)

    fun binaryTree(k: Int): Tree? =
    if (k == 0) null else Tree(binaryTree(k - 1), binaryTree(k - 1))
    val k = 20
    println("=== binaryTree($k)")
    val binaryTree = binaryTree(k)

    val dk = depth(binaryTree)
    println("depth = $dk")
    println("even = ${mutualRecursion.even(binaryTree)}")
    println(" odd = ${mutualRecursion.odd(binaryTree)}")
    check(dk == k)

    println("================")
    // Mix of call & callRecursive
    val mix = object {
    val b: DeepRecursiveFunction<Int, String> = deepRecursive { i -> "b$i" }
    @@ -234,4 +260,23 @@ fun main() {
    val s = mix.a(0)
    println("s = $s")
    check(s == "b1a2a3")

    // Mutually recursive tail calls & broken equals
    val tailRecs = object {
    var nullCount = 0
    val a: DeepRecursiveFunction<Tree?, BadEquals> = deepRecursive { t ->
    if (t == null) BadEquals(nullCount++) else b(t.left)
    }
    val b: DeepRecursiveFunction<Tree?, BadEquals> = deepRecursive { t ->
    if (t == null) BadEquals(nullCount++) else a(t.left)
    }
    }
    println("tailRecs = ${tailRecs.a(deepTree)}")
    check(tailRecs.nullCount == 1)
    }

    // It is equals to any other class
    private class BadEquals(val index: Int) {
    override fun equals(other: Any?): Boolean = true
    override fun toString(): String = "OK"
    }
  8. elizarov revised this gist Nov 22, 2019. 1 changed file with 147 additions and 34 deletions.
    181 changes: 147 additions & 34 deletions RecursiveFunction.kt
    Original file line number Diff line number Diff line change
    @@ -3,13 +3,14 @@ import kotlin.coroutines.intrinsics.*
    import kotlin.math.*

    /**
    * Defines recursive function that keeps its stack on the heap,
    * Defines deep recursive function that keeps its stack on the heap,
    * which allows very deep recursive computations that do not use the actual call stack.
    *
    * The function takes one parameter of type [T] and returns a result of type [R].
    * The [block] of code defines the body of a recursive function. In this block
    * [rec][RecursiveFunctionScope.rec] function should be used to make a recursive call
    * to the declared function.
    * [callRecursive][DeepRecursiveScope.callRecursive] function can be used to make a recursive call
    * to the declared function. Other instances of [DeepRecursiveFunction] can be invoked
    * in this scope, too.
    *
    * For example, take a look at the following recursive tree class and a deeply
    * recursive instance of this tree with a million nodes:
    @@ -26,48 +27,97 @@ import kotlin.math.*
    * if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1
    * ```
    *
    * If this `depth` function is called for a `deepTree` it will produce [StackOverflowError] because
    * If this `depth` function is called for a `deepTree` it produces [StackOverflowError] because
    * of deep recursion unless a very large stack size is used.
    *
    * However, the `depth` function can be rewritten using `recursiveFunction` in the following way so that
    * However, the `depth` function can be rewritten using `deepRecursive` function in the following way so that
    * it successfully computes `depth(deepTree)`:
    *
    * ```
    * val depth = recursiveFunction<Tree?, Int> { t ->
    * if (t == null) 0 else max(rec(t.left), rec(t.right)) + 1
    * val depth = deepRecursive<Tree?, Int> { t ->
    * if (t == null) 0 else max(callRecursive(t.left), callRecursive(t.right)) + 1
    * }
    * ```
    *
    * Deep recursive functions can also mutually call each other using a heap for the stack. The invocation of
    * [DeepRecursiveFunction] in [DeepRecursiveScope] automatically resolves properly. For example, the
    * following pair of mutually recursive functions computes the number of tree nodes at even depth in the tree.
    *
    * ```
    * val mutualRecursion = object {
    * val even: DeepRecursiveFunction<Tree?, Int> = deepRecursive { t ->
    * if (t == null) 0 else odd(t.left) + odd(t.right) + 1
    * }
    * val odd: DeepRecursiveFunction<Tree?, Int> = deepRecursive { t ->
    * if (t == null) 0 else even(t.left) + even(t.right)
    * }
    * }
    * ```
    *
    * @param [T] function parameter type.
    * @param [R] function result type.
    * @param block function body.
    */
    fun <T, R> recursiveFunction(block: suspend RecursiveFunctionScope<T, R>.(T) -> R): (T) -> R =
    { value ->
    RecursiveFunctionImpl(block, value).call()
    }
    public fun <T, R> deepRecursive(block: suspend DeepRecursiveScope<T, R>.(T) -> R): DeepRecursiveFunction<T, R> =
    DeepRecursiveFunction(block)

    /**
    * A scope class for [recursiveFunction] declaration that defines [rec] function.
    *
    * An instance of deep recursive function that is created by [deepRecursive] builder.
    * To initiate a call to this deep recursive function use operator [invoke] extension.
    */
    public class DeepRecursiveFunction<T, R> internal constructor(
    internal val block: suspend DeepRecursiveScope<T, R>.(T) -> R
    ) {
    internal fun callImpl(value: T): R = DeepRecursiveScopeImpl<T, R>(block, value).runCallLoop()
    }

    /**
    * Initiates a call to this deep recursive function, forming a root of the call tree.
    */
    public operator fun <T, R> DeepRecursiveFunction<T, R>.invoke(value: T): R = callImpl(value)

    /**
    * A scope class for [deepRecursive] function declaration that defines [callRecursive] function and
    * provides ability to [invoke] any [DeepRecursiveFunction] putting the call activation frame on the heap.
    *
    * @param [T] function parameter type.
    * @param [R] function result type.
    */
    @RestrictsSuspension
    abstract class RecursiveFunctionScope<T, R> {
    public sealed class DeepRecursiveScope<T, R> {
    /**
    * Makes recursive call to this [recursiveFunction] putting the call activation frame on the heap,
    * Makes recursive call to this [deepRecursive] function putting the call activation frame on the heap,
    * as opposed to the actual call stack that is used by a regular recursive call.
    */
    abstract suspend fun rec(value: T): R
    public abstract suspend fun callRecursive(value: T): R

    /**
    * Makes call to the specified [DeepRecursiveFunction] function putting the call activation frame on the heap,
    * as opposed to the actual call stack that is used by a regular call.
    */
    public abstract suspend operator fun <U, S> DeepRecursiveFunction<U, S>.invoke(value: U): S
    }

    private class RecursiveFunctionImpl<T, R>(
    private val block: suspend RecursiveFunctionScope<T, R>.(T) -> R,
    private var value: T? = null
    ) : RecursiveFunctionScope<T, R>(), Continuation<R> {
    // ================== Implementation ==================

    private typealias DeepRecursiveFunctionBlock = Function3<Any?, Any?, Continuation<Any?>?, Any?>

    @Suppress("UNCHECKED_CAST")
    private class DeepRecursiveScopeImpl<T, R>(
    block: suspend DeepRecursiveScope<T, R>.(T) -> R,
    value: T
    ) : DeepRecursiveScope<T, R>(), Continuation<R> {
    // Active function block
    private var function: DeepRecursiveFunctionBlock = block as DeepRecursiveFunctionBlock

    // Value to call function with
    private var value: Any? = value

    // Continuation of the current call
    private var cont: Continuation<Any?>? = this as Continuation<Any?>

    // Completion result (completion of the whole call stack)
    private var result: Result<Any?> = Result.success(null)
    private var cont: Continuation<R>? = this

    override val context: CoroutineContext
    get() = EmptyCoroutineContext
    @@ -77,21 +127,42 @@ private class RecursiveFunctionImpl<T, R>(
    this.result = result
    }

    override suspend fun rec(value: T): R = suspendCoroutineUninterceptedOrReturn { cont ->
    this.cont = cont
    override suspend fun callRecursive(value: T): R = suspendCoroutineUninterceptedOrReturn { cont ->
    // calling the same function that is currently active
    this.cont = cont as Continuation<Any?>
    this.value = value
    COROUTINE_SUSPENDED
    }

    override suspend fun <U, S> DeepRecursiveFunction<U, S>.invoke(value: U): S = suspendCoroutineUninterceptedOrReturn { cont ->
    // calling another recursive function
    val function = block as DeepRecursiveFunctionBlock
    with(this@DeepRecursiveScopeImpl) {
    val currentFunction = this.function
    if (function !== currentFunction) {
    // calling a different function -- create a trampoline to restore function ref
    this.function = function
    this.cont = Continuation(EmptyCoroutineContext) {
    this.function = currentFunction
    (cont as Continuation<Any?>).resumeWith(it)
    }
    } else {
    // calling the same function -- direct
    this.cont = cont as Continuation<Any?>
    }
    this.value = value
    }
    COROUTINE_SUSPENDED
    }

    @Suppress("UNCHECKED_CAST")
    fun call(): R {
    val f = block as Function3<Any?, T?, Continuation<R>?, Any?>
    fun runCallLoop(): R {
    while (true) {
    val cont = this.cont
    ?: return (result as Result<R>).getOrThrow()
    // This is block.startCoroutine(this, value, cont)
    val r = try {
    f(this, value, cont)
    function(this, value, cont)
    } catch (e: Throwable) {
    cont.resumeWithException(e)
    continue
    @@ -107,18 +178,60 @@ private class RecursiveFunctionImpl<T, R>(
    class Tree(val left: Tree? = null, val right: Tree? = null)

    fun main() {
    val n = 1_000_000
    val deepTree = generateSequence(Tree()) { Tree(it) }.take(n).last()

    // Regular code: Stack overflow error
    // fun depth(t: Tree?): Int =
    // if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1

    val depth = recursiveFunction<Tree?, Int> { t ->
    if (t == null) 0 else max(rec(t.left), rec(t.right)) + 1
    val depth = deepRecursive<Tree?, Int> { t ->
    if (t == null) 0 else max(callRecursive(t.left), callRecursive(t.right)) + 1
    }

    // A pair of mutually-recursive functions
    val mutualRecursion = object {
    val even: DeepRecursiveFunction<Tree?, Int> = deepRecursive { t ->
    if (t == null) 0 else odd(t.left) + odd(t.right) + 1
    }

    val odd: DeepRecursiveFunction<Tree?, Int> = deepRecursive { t ->
    if (t == null) 0 else even(t.left) + even(t.right)
    }
    }

    val d = depth(deepTree)
    println("depth = $d")
    check(d == n)
    val n = 1_000_000
    println("=== deepTree($n)")
    val deepTree = generateSequence(Tree()) { Tree(it) }.take(n).last()

    val dn = depth(deepTree)
    println("depth = $dn")
    println("even = ${mutualRecursion.even(deepTree)}")
    check(dn == n)

    fun binaryTree(k: Int): Tree? =
    if (k == 0) null else Tree(binaryTree(k - 1), binaryTree(k - 1))
    val k = 20
    println("=== binaryTree($k)")
    val binaryTree = binaryTree(k)

    val dk = depth(binaryTree)
    println("depth = $dk")
    println("even = ${mutualRecursion.even(binaryTree)}")
    check(dk == k)

    // Mix of call & callRecursive
    val mix = object {
    val b: DeepRecursiveFunction<Int, String> = deepRecursive { i -> "b$i" }
    lateinit var a: DeepRecursiveFunction<Int, String>

    init {
    a = deepRecursive { i ->
    when (i) {
    0 -> b(1) + callRecursive(2) + a(3)
    else -> "a$i"
    }
    }
    }
    }
    val s = mix.a(0)
    println("s = $s")
    check(s == "b1a2a3")
    }
  9. elizarov revised this gist Jun 1, 2019. 1 changed file with 4 additions and 4 deletions.
    8 changes: 4 additions & 4 deletions RecursiveFunction.kt
    Original file line number Diff line number Diff line change
    @@ -6,10 +6,10 @@ import kotlin.math.*
    * Defines recursive function that keeps its stack on the heap,
    * which allows very deep recursive computations that do not use the actual call stack.
    *
    * The function takes one parameter of type [T] and returns a result of type [R]
    * The [block] of code defines the body of a recursive function and in this block
    * [rec][RecursiveFunctionScope] function should be used to make a recursive call
    * the declared function.
    * The function takes one parameter of type [T] and returns a result of type [R].
    * The [block] of code defines the body of a recursive function. In this block
    * [rec][RecursiveFunctionScope.rec] function should be used to make a recursive call
    * to the declared function.
    *
    * For example, take a look at the following recursive tree class and a deeply
    * recursive instance of this tree with a million nodes:
  10. elizarov revised this gist Jun 1, 2019. 1 changed file with 4 additions and 2 deletions.
    6 changes: 4 additions & 2 deletions RecursiveFunction.kt
    Original file line number Diff line number Diff line change
    @@ -8,7 +8,8 @@ import kotlin.math.*
    *
    * The function takes one parameter of type [T] and returns a result of type [R]
    * The [block] of code defines the body of a recursive function and in this block
    * [rec][RecursiveFunctionScope] function is used to make a recursive call.
    * [rec][RecursiveFunctionScope] function should be used to make a recursive call
    * the declared function.
    *
    * For example, take a look at the following recursive tree class and a deeply
    * recursive instance of this tree with a million nodes:
    @@ -55,7 +56,8 @@ fun <T, R> recursiveFunction(block: suspend RecursiveFunctionScope<T, R>.(T) ->
    @RestrictsSuspension
    abstract class RecursiveFunctionScope<T, R> {
    /**
    * Makes recursive call to this [recursiveFunction].
    * Makes recursive call to this [recursiveFunction] putting the call activation frame on the heap,
    * as opposed to the actual call stack that is used by a regular recursive call.
    */
    abstract suspend fun rec(value: T): R
    }
  11. elizarov created this gist Jun 1, 2019.
    122 changes: 122 additions & 0 deletions RecursiveFunction.kt
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,122 @@
    import kotlin.coroutines.*
    import kotlin.coroutines.intrinsics.*
    import kotlin.math.*

    /**
    * Defines recursive function that keeps its stack on the heap,
    * which allows very deep recursive computations that do not use the actual call stack.
    *
    * The function takes one parameter of type [T] and returns a result of type [R]
    * The [block] of code defines the body of a recursive function and in this block
    * [rec][RecursiveFunctionScope] function is used to make a recursive call.
    *
    * For example, take a look at the following recursive tree class and a deeply
    * recursive instance of this tree with a million nodes:
    *
    * ```
    * class Tree(val left: Tree? = null, val right: Tree? = null)
    * val deepTree = generateSequence(Tree()) { Tree(it) }.take(1_000_000).last()
    * ```
    *
    * A regular recursive function can be defined to compute a depth of a tree:
    *
    * ```
    * fun depth(t: Tree?): Int =
    * if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1
    * ```
    *
    * If this `depth` function is called for a `deepTree` it will produce [StackOverflowError] because
    * of deep recursion unless a very large stack size is used.
    *
    * However, the `depth` function can be rewritten using `recursiveFunction` in the following way so that
    * it successfully computes `depth(deepTree)`:
    *
    * ```
    * val depth = recursiveFunction<Tree?, Int> { t ->
    * if (t == null) 0 else max(rec(t.left), rec(t.right)) + 1
    * }
    * ```
    *
    * @param [T] function parameter type.
    * @param [R] function result type.
    * @param block function body.
    */
    fun <T, R> recursiveFunction(block: suspend RecursiveFunctionScope<T, R>.(T) -> R): (T) -> R =
    { value ->
    RecursiveFunctionImpl(block, value).call()
    }

    /**
    * A scope class for [recursiveFunction] declaration that defines [rec] function.
    *
    * @param [T] function parameter type.
    * @param [R] function result type.
    */
    @RestrictsSuspension
    abstract class RecursiveFunctionScope<T, R> {
    /**
    * Makes recursive call to this [recursiveFunction].
    */
    abstract suspend fun rec(value: T): R
    }

    private class RecursiveFunctionImpl<T, R>(
    private val block: suspend RecursiveFunctionScope<T, R>.(T) -> R,
    private var value: T? = null
    ) : RecursiveFunctionScope<T, R>(), Continuation<R> {
    private var result: Result<Any?> = Result.success(null)
    private var cont: Continuation<R>? = this

    override val context: CoroutineContext
    get() = EmptyCoroutineContext

    override fun resumeWith(result: Result<R>) {
    this.cont = null
    this.result = result
    }

    override suspend fun rec(value: T): R = suspendCoroutineUninterceptedOrReturn { cont ->
    this.cont = cont
    this.value = value
    COROUTINE_SUSPENDED
    }

    @Suppress("UNCHECKED_CAST")
    fun call(): R {
    val f = block as Function3<Any?, T?, Continuation<R>?, Any?>
    while (true) {
    val cont = this.cont
    ?: return (result as Result<R>).getOrThrow()
    // This is block.startCoroutine(this, value, cont)
    val r = try {
    f(this, value, cont)
    } catch (e: Throwable) {
    cont.resumeWithException(e)
    continue
    }
    if (r !== COROUTINE_SUSPENDED)
    cont.resume(r as R)
    }
    }
    }

    // ============== Test code ==============

    class Tree(val left: Tree? = null, val right: Tree? = null)

    fun main() {
    val n = 1_000_000
    val deepTree = generateSequence(Tree()) { Tree(it) }.take(n).last()

    // Regular code: Stack overflow error
    // fun depth(t: Tree?): Int =
    // if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1

    val depth = recursiveFunction<Tree?, Int> { t ->
    if (t == null) 0 else max(rec(t.left), rec(t.right)) + 1
    }

    val d = depth(deepTree)
    println("depth = $d")
    check(d == n)
    }