Skip to content

Instantly share code, notes, and snippets.

@thewisenerd
Created June 11, 2025 23:16
Show Gist options
  • Select an option

  • Save thewisenerd/60be82e9dc2f611ba6a540e008579314 to your computer and use it in GitHub Desktop.

Select an option

Save thewisenerd/60be82e9dc2f611ba6a540e008579314 to your computer and use it in GitHub Desktop.

Revisions

  1. thewisenerd created this gist Jun 11, 2025.
    81 changes: 81 additions & 0 deletions async-worker-pool-generator.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,81 @@
    import asyncio
    import typing
    from dataclasses import dataclass

    # NOTE: this effectively allows for unbounded coroutine creation
    # gated by the semaphore concurrency. check memory.

    T = typing.TypeVar("T")
    R = typing.TypeVar("R")
    C = typing.TypeVar("C")


    class Result(typing.Generic[T]):
    def __init__(self, ok: bool, value: T | None, reason: str | None = None):
    self.ok = ok
    self.value = value
    self.reason = reason

    @staticmethod
    def success(value: T) -> "Result[T]":
    return Result(ok=True, value=value, reason=None)

    @staticmethod
    def error(reason: str) -> "Result[T]":
    return Result(ok=False, value=None, reason=reason)


    @dataclass
    class TaskResult(typing.Generic[T, R, C]):
    task: T
    result: Result[R] | None
    children: list[C]

    TaskType = typing.Union[RootTask, ChildTask] # example..

    async def worker(
    sem: asyncio.Semaphore,
    task: TaskType,
    ) -> TaskResult[TaskType, Result[Project], TaskType]:
    async with sem:
    return await task_impl(task)


    async def main_impl(
    concurrency: int
    ) -> typing.AsyncGenerator[Result[Project], None]:
    if concurrency < 1:
    raise ValueError("concurrency must be >= 1")

    sem = asyncio.Semaphore(concurrency)
    pending = set()
    total = 0
    completed = 0

    task = asyncio.create_task(worker(sem, RootTask()))
    pending.add(task)
    total += 1
    pbar = tqdm.tqdm(total=1, bar_format='[{bar}] {n_fmt}/{total_fmt}')

    while pending:
    done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
    for task in done:
    task_result = await task
    completed += 1
    pbar.update(1)

    if not isinstance(task_result, TaskResult):
    raise TypeError(f"Expected TaskResult, got {type(task_result)}")

    if task_result.result is not None:
    yield task_result.result

    if task_result.children:
    for child_task in task_result.children:
    new_task = asyncio.create_task(worker(sem, child_task))
    pending.add(new_task)
    total += 1

    pbar.total = total

    pbar.close()