Created
February 20, 2024 08:23
-
-
Save wilzh40/9a60801b335dde1685838d831c5396b5 to your computer and use it in GitHub Desktop.
Cvar threadpool
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import concurrent.futures | |
| import typing | |
| from collections import deque | |
| import threading | |
| import os | |
| from collections.abc import Iterable, Callable | |
| import time | |
| import random | |
class ThreadPool:
    """Common interface for the thread-pool implementations in this file.

    Every method here is a no-op placeholder that returns ``None``;
    subclasses override the ones they implement.
    """

    def __init__(self, num_workers: typing.Optional[int] = None):
        """Accept the requested worker count; the base class ignores it."""

    def map(self, fn: Callable, iterable: Iterable,
            kwargs: typing.Optional[dict] = None):
        """Apply ``fn`` to every item of ``iterable`` (no-op in the base class).

        ``kwargs`` defaults to ``None`` rather than ``{}`` so that no mutable
        object is shared between calls.
        """

    def submit(self, fn: Callable, args: tuple = (),
               kwargs: typing.Optional[dict] = None):
        """Schedule ``fn(*args, **kwargs)`` (no-op in the base class)."""

    def shutdown(self, wait=True):
        """Stop the pool (no-op in the base class)."""
class Task:
    """A deferred call ``fn(*args, **kwargs)`` plus slots for its outcome.

    Attributes:
        closure: zero-argument callable that performs the call.
        result: value returned by the last successful ``run()``; ``None``
            until then.
        error: initialised to ``None``; ``run()`` never sets it, it is a
            slot for whoever executes the task to record a failure.
    """

    def __init__(self, fn: Callable, args: tuple = (),
                 kwargs: typing.Optional[dict] = None):
        # ``None`` replaces the shared mutable ``{}`` default; a dict passed
        # by the caller is used as-is (not copied).
        call_kwargs = {} if kwargs is None else kwargs
        self.closure = lambda: fn(*args, **call_kwargs)
        self.result = None
        self.error = None

    def run(self):
        """Execute the call, store its return value in ``result`` and return it.

        Exceptions raised by the callable propagate to the caller unchanged.
        """
        self.result = self.closure()
        return self.result
class SemaphoreThreadPool(ThreadPool):
    """Semaphore-based pool skeleton.

    Only the synchronisation primitives are created here; no worker threads
    are started and ``map``/``submit``/``shutdown`` are inherited unchanged
    from ``ThreadPool``.
    NOTE(review): this looks like an unfinished implementation -- confirm.
    """

    def __init__(self, num_workers: typing.Optional[int] = None, limit=50000):
        """Set up the lock and the two counting semaphores.

        Args:
            num_workers: desired worker count. When ``None`` it defaults to
                the CPU count clamped to the range [4, 32].
            limit: initial value of ``max_sem``.
        """
        if num_workers is None:
            # The previous ``min(4, max(32, n))`` always evaluated to 4 and
            # raised TypeError when os.cpu_count() returned None.
            num_workers = max(4, min(32, os.cpu_count() or 1))
        # Kept so the computed value is not thrown away.
        self.num_workers = num_workers
        self.mutex = threading.Lock()
        # Starts at 0 -- presumably counts queued tasks; verify when the
        # implementation is completed.
        self.min_sem = threading.Semaphore(0)
        # Starts at ``limit`` -- presumably counts free queue slots; verify.
        self.max_sem = threading.Semaphore(limit)
class CvarThreadPool(ThreadPool):
    """Thread pool built on one lock and condition variables.

    All shared state (``queue``, ``pending_tasks``) is guarded by ``mutex``.
    Three conditions share that lock so that each kind of waiter is only
    woken by the event it cares about:

    * ``cvar``       -- workers wait here for work (or shutdown).
    * ``_not_full``  -- submitters wait here for the queue to shrink.
    * ``_all_done``  -- ``map`` waits here for ``pending_tasks`` to hit 0.
    """

    def __init__(self, num_workers: typing.Optional[int] = None, limit=50000):
        """Start the worker threads.

        Args:
            num_workers: number of worker threads. When ``None`` it defaults
                to the CPU count clamped to the range [4, 32].
            limit: ``submit`` blocks while more than ``limit`` tasks are
                queued.
        """
        if num_workers is None:
            # The previous ``min(4, max(32, n))`` always evaluated to 4 and
            # raised TypeError when os.cpu_count() returned None.
            num_workers = max(4, min(32, os.cpu_count() or 1))
        self.mutex = threading.Lock()
        self.cvar = threading.Condition(self.mutex)
        self._not_full = threading.Condition(self.mutex)
        self._all_done = threading.Condition(self.mutex)
        self.shutdown_event = threading.Event()
        self.queue = deque()
        self.limit = limit
        # Tasks that were submitted and have not finished (queued + running).
        self.pending_tasks = 0
        prefix = "CvarThreadPool"
        self.threads = {
            threading.Thread(target=self._work_loop, name=f"{prefix}{i}")
            for i in range(num_workers)
        }
        for thread in self.threads:
            thread.start()

    def _work_loop(self):
        """Worker body: run queued tasks until the pool is shut down."""
        while True:
            with self.mutex:
                # Re-check both conditions after every wake-up; shutdown must
                # be able to end the wait even though the queue is empty.
                while not self.queue and not self.shutdown_event.is_set():
                    self.cvar.wait()
                if self.shutdown_event.is_set():
                    return
                task = self.queue.popleft()
                # One slot was freed: let one blocked submitter continue.
                self._not_full.notify()
            try:
                task.run()
            except Exception as e:
                # Record the failure on the task instead of killing the worker.
                task.error = e
            finally:
                with self.mutex:
                    self.pending_tasks -= 1
                    if self.pending_tasks == 0:
                        self._all_done.notify_all()

    def map(self, fn: Callable, iterable: Iterable,
            kwargs: typing.Optional[dict] = None):
        """Submit ``fn(item, **kwargs)`` for every item and wait for the pool.

        Blocks until the pool has no pending tasks at all, which includes
        tasks submitted by other callers.

        Returns:
            The list of submitted ``Task`` objects in input order; read
            ``task.result`` for the value or ``task.error`` for a captured
            exception.

        Raises:
            RuntimeError: if the pool has been shut down.
        """
        tasks = [self.submit(fn, args=(arg,), kwargs=kwargs)
                 for arg in iterable]
        with self.mutex:
            while self.pending_tasks > 0:
                self._all_done.wait()
            # All tasks are completed.
            print(f"All tasks completed: {self.pending_tasks}")
        return tasks

    def submit(self, fn, args: tuple = (),
               kwargs: typing.Optional[dict] = None):
        """Queue ``fn(*args, **kwargs)`` and return its ``Task``.

        Blocks while more than ``limit`` tasks are queued.

        Raises:
            RuntimeError: if the pool has been (or is being) shut down.
        """
        # Always hand Task a real dict so this works with any Task signature.
        task = Task(fn, args, {} if kwargs is None else kwargs)
        with self.mutex:
            while (len(self.queue) > self.limit
                   and not self.shutdown_event.is_set()):
                self._not_full.wait()
            if self.shutdown_event.is_set():
                # No worker would ever pick the task up.
                raise RuntimeError(
                    "cannot submit to a CvarThreadPool after shutdown")
            self.queue.append(task)
            self.pending_tasks += 1
            # Only workers wait on ``cvar``, so this wake-up cannot be lost
            # to a submitter or a ``map`` caller.
            self.cvar.notify()
        return task

    def shutdown(self, wait=True):
        """Stop the workers, discarding tasks that have not started yet.

        Tasks already running are allowed to finish.

        Args:
            wait: when true, join every worker thread before returning.
        """
        print("Shutting down.")
        with self.mutex:
            self.shutdown_event.set()
            # Discarded tasks will never run, so stop counting them; otherwise
            # a concurrent ``map`` would wait forever.
            self.pending_tasks -= len(self.queue)
            self.queue.clear()
            self.cvar.notify_all()
            self._not_full.notify_all()
            if self.pending_tasks == 0:
                self._all_done.notify_all()
        if wait:
            for thread in self.threads:
                thread.join()
def cpu_bound(duration):
    """Busy-wait for ``duration`` seconds and return ``duration``.

    Stands in for a CPU-heavy task. The deadline is measured on the
    monotonic clock so that wall-clock adjustments cannot shorten or extend
    the spin. A zero or negative ``duration`` returns immediately.
    """
    deadline = time.monotonic() + duration
    while time.monotonic() < deadline:
        # Spin loop: CPU heavy!
        pass
    return duration
def io_bound(duration):
    """Sleep for ``duration`` seconds and return ``duration``.

    The sleep is a system call into the kernel, which makes this a stand-in
    for an I/O-bound task.
    """
    time.sleep(duration)
    return duration
def profile_pool(num_workers_list: list[int], fn: Callable,
                 num_tasks: int = 50, type: str = "cvar",
                 duration: float = 0.1):
    """Time how long a pool takes to map ``fn`` over ``num_tasks`` inputs.

    Args:
        num_workers_list: worker counts to benchmark, one pool per entry.
        fn: task function; it is called with ``duration`` as its only
            argument.
        num_tasks: number of tasks mapped per pool.
        type: ``"default"`` benchmarks
            ``concurrent.futures.ThreadPoolExecutor``; any other value
            benchmarks ``CvarThreadPool``.
        duration: argument passed to every ``fn`` call (previously fixed
            at 0.1).

    Returns:
        Dict mapping each worker count to the elapsed seconds for
        ``map`` plus ``shutdown``.

    Raises:
        AssertionError: if any thread other than the calling one is still
            alive after the pool has been shut down.
    """
    times = {}
    for num_workers in num_workers_list:
        if type == "default":
            pool = concurrent.futures.ThreadPoolExecutor(
                max_workers=num_workers)
        else:
            pool = CvarThreadPool(num_workers=num_workers)
        start_time = time.perf_counter()
        try:
            pool.map(fn, [duration] * num_tasks)
        finally:
            # Always stop the workers so a failing map does not leak threads.
            pool.shutdown()
        elapsed = time.perf_counter() - start_time
        assert threading.active_count(
        ) == 1, f"Only one active thread (the main thread) must remain, got {threading.active_count()}"
        times[num_workers] = elapsed
    return times
if __name__ == "__main__":
    # Benchmark the stdlib executor and CvarThreadPool on I/O-bound and
    # CPU-bound tasks across 2..512 workers, then print the timings.
    #
    # Timings presumably recorded from an earlier run (worker count -> seconds):
    # Default IO: {2: 13.275235374923795, 4: 6.655313625000417, 8: 3.3173186670755967, 16: 1.6627995839808136, 32: 0.8406297919573262, 64: 0.42237862502224743, 128: 0.21964483300689608, 256: 0.12271529098507017, 512: 0.11788716702722013}
    # Default CPU: {2: 12.96525054203812, 4: 6.929109917022288, 8: 4.171173915965483, 16: 3.3885942089837044, 32: 3.4872594580519944, 64: 3.377452834043652, 128: 3.634125957963988, 256: 3.5097480000695214, 512: 3.2952734159771353}
    # IO Bound: {2: 13.285975374979898, 4: 6.644823250011541, 8: 3.3390898749930784, 16: 1.6570523750269786, 32: 0.8338112080236897, 64: 0.4214848750270903, 128: 0.22665741597302258, 256: 0.12575916689820588, 512: 0.14143420895561576}
    # CPU Bound: {2: 12.980181750026532, 4: 6.9501094170846045, 8: 4.130237541976385, 16: 3.182236749911681, 32: 2.6954647090751678, 64: 2.2214373329188675, 128: 1.8402691250666976, 256: 1.6856321659870446, 512: 1.665079541038721}
    worker_counts = [2**i for i in range(1, 10)]
    task_count = 2 ** 8

    default_io_times = profile_pool(
        worker_counts, io_bound, task_count, "default")
    default_cpu_times = profile_pool(
        worker_counts, cpu_bound, task_count, "default")
    io_times = profile_pool(worker_counts, io_bound, task_count)
    cpu_times = profile_pool(worker_counts, cpu_bound, task_count)

    print(f"Default IO: {default_io_times}")
    print(f"Default CPU: {default_cpu_times}")
    print(f"IO Bound: {io_times}")
    print(f"CPU Bound: {cpu_times}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment