Last active
March 25, 2024 00:40
-
-
Save elizarov/861dda8c3e8c5ee36eaa6db4ad996568 to your computer and use it in GitHub Desktop.
Defines a recursive function that keeps its stack on the heap (productized version)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import kotlin.coroutines.* | |
| import kotlin.coroutines.intrinsics.* | |
| import kotlin.math.* | |
/**
 * Builds a recursive function whose recursion stack is kept on the heap, so that very deep
 * recursive computations can run without consuming the actual call stack.
 *
 * The resulting function accepts a single argument of type [T] and produces a value of type [R].
 * Its body is given by [block]; inside that body a recursive call is made with
 * [rec][RecursiveFunctionScope.rec].
 *
 * As an illustration, consider this recursive tree class together with a deeply nested
 * instance of it that has a million nodes:
 *
 * ```
 * class Tree(val left: Tree? = null, val right: Tree? = null)
 * val deepTree = generateSequence(Tree()) { Tree(it) }.take(1_000_000).last()
 * ```
 *
 * The depth of a tree can be computed by an ordinary recursive function:
 *
 * ```
 * fun depth(t: Tree?): Int =
 *     if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1
 * ```
 *
 * Calling that `depth` on `deepTree` fails with [StackOverflowError] because the recursion
 * is too deep, unless a very large stack size is configured.
 *
 * Expressed through `recursiveFunction`, the same computation completes successfully for
 * `depth(deepTree)`:
 *
 * ```
 * val depth = recursiveFunction<Tree?, Int> { t ->
 *     if (t == null) 0 else max(rec(t.left), rec(t.right)) + 1
 * }
 * ```
 *
 * @param T function parameter type.
 * @param R function result type.
 * @param block function body.
 * @return a function that, on every invocation, runs [block] for the given argument and
 * returns its result (or rethrows its failure).
 */
fun <T, R> recursiveFunction(block: suspend RecursiveFunctionScope<T, R>.(T) -> R): (T) -> R {
    // Each invocation gets its own implementation object, so separate calls share no state.
    return fun(argument: T): R = RecursiveFunctionImpl(block, argument).call()
}
/**
 * Receiver type for the body passed to [recursiveFunction]; it is what makes [rec]
 * available inside that body.
 *
 * The class is annotated with [RestrictsSuspension].
 *
 * @param T function parameter type.
 * @param R function result type.
 */
@RestrictsSuspension
abstract class RecursiveFunctionScope<T, R> {

    /**
     * Performs a recursive call of the enclosing [recursiveFunction] with the given [value]
     * and returns the result of that call.
     */
    abstract suspend fun rec(value: T): R
}
/**
 * Engine behind [recursiveFunction]. One instance serves a single top-level invocation and
 * plays two roles at once: it is the scope handed to [block] (providing [rec]) and it is the
 * continuation that receives the result of the outermost invocation of [block].
 *
 * Mutable state:
 * - `value`  — argument for the next invocation of [block];
 * - `cont`   — continuation the next invocation of [block] completes into, or `null` once
 *              the outermost invocation has finished;
 * - `result` — outcome of the outermost invocation; read only after `cont` became `null`.
 */
private class RecursiveFunctionImpl<T, R>(
    private val block: suspend RecursiveFunctionScope<T, R>.(T) -> R,
    private var value: T? = null
) : RecursiveFunctionScope<T, R>(), Continuation<R> {

    // Placeholder until resumeWith stores the real outcome.
    private var result: Result<Any?> = Result.success(null)

    // Initially this object itself: the first invocation of block completes into resumeWith below.
    private var cont: Continuation<R>? = this

    override val context: CoroutineContext get() = EmptyCoroutineContext

    /**
     * Receives the outcome of the outermost invocation of [block]: stores it and clears
     * `cont`, which is the signal for [call] to stop looping.
     */
    override fun resumeWith(result: Result<R>) {
        this.result = result
        this.cont = null
    }

    /**
     * Records the requested recursive call (the caller's continuation and the new argument)
     * and always suspends; the invocation itself is carried out later by the loop in [call].
     */
    override suspend fun rec(value: T): R =
        suspendCoroutineUninterceptedOrReturn { caller ->
            this.value = value
            this.cont = caller
            COROUTINE_SUSPENDED
        }

    /**
     * Runs the computation: while a continuation is pending, invokes [block] for the current
     * argument with that continuation as its completion. Once none is pending, returns the
     * stored result or rethrows the stored failure.
     */
    @Suppress("UNCHECKED_CAST")
    fun call(): R {
        // NOTE(review): this cast relies on how a suspend lambda with a receiver and one
        // parameter is represented at runtime (callable as Function3) — confirm for the target platform.
        val body = block as Function3<Any?, T?, Continuation<R>?, Any?>
        var pending = cont
        while (pending != null) {
            runStep(body, pending)
            pending = cont
        }
        return (result as Result<R>).getOrThrow()
    }

    /**
     * Performs one invocation of [body] for the current `value`, completing into [target].
     * This corresponds to block.startCoroutine(this, value, target).
     */
    @Suppress("UNCHECKED_CAST")
    private fun runStep(body: Function3<Any?, T?, Continuation<R>?, Any?>, target: Continuation<R>) {
        val outcome = try {
            body(this, value, target)
        } catch (failure: Throwable) {
            // The invocation threw directly: deliver the failure to whoever awaited it.
            target.resumeWithException(failure)
            return
        }
        // A suspended invocation has already registered its request through rec();
        // anything else is a finished value that must be delivered to target.
        if (outcome !== COROUTINE_SUSPENDED) {
            target.resume(outcome as R)
        }
    }
}
| // ============== Test code ============== | |
/**
 * Binary tree node used by the demo below; either child may be absent.
 *
 * @property left left subtree, or `null` when there is none.
 * @property right right subtree, or `null` when there is none.
 */
class Tree(
    val left: Tree? = null,
    val right: Tree? = null
)
/**
 * Demo: builds a chain-shaped tree of one million nodes, measures its depth with a function
 * created by [recursiveFunction], prints the depth and checks that it equals the node count.
 */
fun main() {
    val nodeCount = 1_000_000
    // Start from a single leaf and wrap it nodeCount - 1 times, always as the left child.
    val deepTree = (1 until nodeCount).fold(Tree()) { subtree, _ -> Tree(subtree) }

    // Regular code: Stack overflow error
    //    fun depth(t: Tree?): Int =
    //        if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1

    val depth = recursiveFunction<Tree?, Int> { node ->
        if (node == null) return@recursiveFunction 0
        val leftDepth = rec(node.left)
        val rightDepth = rec(node.right)
        max(leftDepth, rightDepth) + 1
    }

    val d = depth(deepTree)
    println("depth = $d")
    check(d == nodeCount)
}
Author
You can use it under the Apache License 2.0 (APL 2.0). It will be merged into the Kotlin stdlib. See JetBrains/kotlin#3398
Thank you!
The link to the related blogpost: https://medium.com/@elizarov/deep-recursion-with-coroutines-7c53e15993e3
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Nice!
I would like to use this in an (open-source) project. Would it be possible to add a license to it?