Skip to content

Instantly share code, notes, and snippets.

@wilzh40
Created February 20, 2024 08:23
Show Gist options
  • Select an option

  • Save wilzh40/9a60801b335dde1685838d831c5396b5 to your computer and use it in GitHub Desktop.

Select an option

Save wilzh40/9a60801b335dde1685838d831c5396b5 to your computer and use it in GitHub Desktop.
Cvar threadpool
import concurrent.futures
import typing
from collections import deque
import threading
import os
from collections.abc import Iterable, Callable
import time
import random
class ThreadPool:
    """Common interface for the thread-pool implementations in this file.

    Every method here is a no-op that returns None; subclasses override the
    ones they implement.
    """

    def __init__(self, num_workers: typing.Optional[int] = None):
        """Accept an optional worker count (ignored by this base class)."""

    def map(self, fn: Callable, iterable: Iterable, kwargs: typing.Optional[dict] = None):
        """Apply ``fn(item, **kwargs)`` to each item of *iterable* (no-op here).

        ``kwargs`` defaults to None rather than ``{}`` so that no mutable
        default is shared between calls.
        """

    def submit(self, fn: Callable, args: tuple = (), kwargs: typing.Optional[dict] = None):
        """Schedule ``fn(*args, **kwargs)`` for execution (no-op here)."""

    def shutdown(self, wait=True):
        """Release the pool's resources; *wait* selects blocking (no-op here)."""
class Task:
    """A deferred call ``fn(*args, **kwargs)`` plus slots for its outcome.

    Attributes:
        closure: zero-argument callable that performs the call.
        result: return value of the last successful ``run`` (None before that).
        error: left as None by this class; a caller (e.g. a worker loop) may
            store the exception raised by ``run`` here.
    """

    def __init__(self, fn: Callable, args: tuple = (), kwargs: typing.Optional[dict] = None):
        # Snapshot the arguments: callers such as a pool's ``map`` reuse one
        # kwargs dict for many tasks, and mutating it after submission must
        # not change what an already-queued task will run.
        args = tuple(args)
        kwargs = dict(kwargs) if kwargs is not None else {}
        self.closure = lambda: fn(*args, **kwargs)
        self.result = None
        self.error = None

    def run(self):
        """Execute the call, store its return value in ``result`` and return it.

        Exceptions raised by ``fn`` propagate unchanged; ``result`` is then
        left untouched.
        """
        self.result = self.closure()
        return self.result
class SemaphoreThreadPool(ThreadPool):
    """Semaphore-based thread pool (work in progress).

    So far only the configuration and synchronisation primitives are set up;
    no worker threads are started, and ``map``/``submit``/``shutdown`` are
    inherited from ``ThreadPool``.
    """

    def __init__(self, num_workers: typing.Optional[int] = None, limit=50000):
        if num_workers is None:
            # Clamp the CPU count into [4, 32]. ``os.cpu_count()`` may return
            # None, hence the ``or 1`` fallback.
            num_workers = min(32, max(4, os.cpu_count() or 1))
        self.num_workers = num_workers
        self.limit = limit
        self.mutex = threading.Lock()
        # NOTE(review): presumably counts queued items that consumers can take;
        # confirm once the pool is implemented.
        self.min_sem = threading.Semaphore(0)
        # NOTE(review): presumably counts free queue slots, bounding the queue
        # at ``limit``; confirm once the pool is implemented.
        self.max_sem = threading.Semaphore(limit)
class CvarThreadPool(ThreadPool):
    """Fixed-size thread pool coordinated by condition variables on one lock.

    Submitted calls are wrapped in ``Task`` objects and placed in a FIFO queue
    holding at most ``limit`` tasks. ``num_workers`` threads, started in the
    constructor, pop tasks and run them. Three conditions share ``self.mutex``:

    * ``cvar``     -- signalled when work is queued or shutdown is requested;
                      idle workers wait on it.
    * ``not_full`` -- signalled when a queue slot frees up; ``submit`` calls
                      blocked on a full queue wait on it.
    * ``all_done`` -- signalled when ``pending_tasks`` drops to zero; ``map``
                      waits on it.

    Keeping the waiter kinds on separate conditions guarantees a ``notify``
    always reaches a thread that can act on it. With a single shared
    condition, a ``notify`` meant for a worker could wake ``map`` instead and
    leave a queued task stranded.
    """

    def __init__(self, num_workers: typing.Optional[int] = None, limit=50000):
        """Create the pool and start its worker threads.

        Args:
            num_workers: number of worker threads; defaults to the CPU count
                clamped into [4, 32].
            limit: maximum number of queued (not yet started) tasks.

        Raises:
            ValueError: if ``num_workers`` or ``limit`` is less than 1 (the
                pool could never make progress).
        """
        if num_workers is None:
            # ``os.cpu_count()`` may return None, hence the ``or 1`` fallback.
            num_workers = min(32, max(4, os.cpu_count() or 1))
        if num_workers < 1:
            raise ValueError("num_workers must be at least 1")
        if limit < 1:
            raise ValueError("limit must be at least 1")
        self.mutex = threading.Lock()
        self.cvar = threading.Condition(self.mutex)
        self.not_full = threading.Condition(self.mutex)
        self.all_done = threading.Condition(self.mutex)
        # Set (under ``mutex``) by ``shutdown``. Workers exit once it is set
        # and the queue has drained.
        self.shutdown_event = threading.Event()
        self.queue = deque()
        self.limit = limit
        # Tasks submitted but not yet finished (queued plus running).
        self.pending_tasks = 0
        prefix = "CvarThreadPool"
        self.threads = {
            threading.Thread(target=self._work_loop, name=prefix + str(i))
            for i in range(num_workers)
        }
        for thread in self.threads:
            thread.start()

    def _work_loop(self):
        """Run queued tasks until shutdown is requested and the queue is empty."""
        while True:
            with self.mutex:
                while not self.queue and not self.shutdown_event.is_set():
                    self.cvar.wait()
                if not self.queue:
                    # Shutdown was requested and nothing is left to run.
                    return
                task = self.queue.popleft()
                # A slot just opened up for a producer blocked in ``submit``.
                self.not_full.notify()
            try:
                task.run()
            except Exception as e:
                # Record the failure on the task instead of killing the worker.
                task.error = e
            finally:
                with self.mutex:
                    self.pending_tasks -= 1
                    if self.pending_tasks == 0:
                        self.all_done.notify_all()

    def map(self, fn: Callable, iterable: Iterable, kwargs: typing.Optional[dict] = None):
        """Submit ``fn(item, **kwargs)`` for each item, then wait for the pool to go idle.

        Returns:
            The list of ``Task`` objects in input order. Read ``.result`` and
            ``.error`` on each one.

        The wait covers *every* pending task in the pool, including tasks
        submitted by other callers. A running task counts as pending, so
        calling ``map`` from inside a task on this same pool would wait on
        itself forever.
        """
        tasks = [self.submit(fn, args=(arg,), kwargs=kwargs) for arg in iterable]
        with self.mutex:
            while self.pending_tasks > 0:
                self.all_done.wait()
            # All tasks are completed.
            print(f"All tasks completed: {self.pending_tasks}")
        return tasks

    def submit(self, fn, args: tuple = (), kwargs: typing.Optional[dict] = None):
        """Queue ``fn(*args, **kwargs)`` and return its ``Task``.

        Blocks while the queue already holds ``limit`` tasks.

        Raises:
            RuntimeError: if the pool has been shut down, including while this
                call was blocked waiting for queue space.
        """
        task = Task(fn, args, kwargs if kwargs is not None else {})
        with self.mutex:
            while len(self.queue) >= self.limit and not self.shutdown_event.is_set():
                self.not_full.wait()
            if self.shutdown_event.is_set():
                raise RuntimeError("cannot schedule new tasks after shutdown")
            self.queue.append(task)
            self.pending_tasks += 1
            self.cvar.notify()
        return task

    def shutdown(self, wait=True):
        """Stop accepting work; workers exit once the already-queued tasks have run.

        Args:
            wait: if true, block until every worker thread has terminated.

        Calling this more than once is harmless.
        """
        with self.mutex:
            self.shutdown_event.set()
            # Wake idle workers so they observe the flag, and blocked
            # producers so they raise instead of waiting forever.
            self.cvar.notify_all()
            self.not_full.notify_all()
        print("Shutting down.")
        if wait:
            for thread in self.threads:
                thread.join()
def cpu_bound(duration):
    """Busy-wait for *duration* seconds to simulate CPU-bound work.

    Returns *duration* unchanged. A zero or negative duration returns
    immediately. The monotonic ``time.perf_counter`` clock is used so that
    wall-clock adjustments cannot shorten or stretch the spin.
    """
    # Compute the deadline once instead of on every iteration.
    deadline = time.perf_counter() + duration
    while time.perf_counter() < deadline:
        # Spin loop: CPU heavy!
        pass
    return duration
def io_bound(duration):
    """Simulate I/O-bound work by sleeping for *duration* seconds.

    The sleep is a blocking system call, standing in for a thread that waits
    on I/O. Returns *duration* unchanged.
    """
    seconds = duration
    time.sleep(seconds)
    return seconds
def profile_pool(num_workers_list: list[int], fn: Callable, num_tasks: int = 50, type: str = "cvar",
                 duration: float = 0.1):
    """Time one ``map`` + ``shutdown`` cycle per pool size.

    Args:
        num_workers_list: worker counts to benchmark; a fresh pool is built
            for each.
        fn: callable invoked once per task with *duration* as its argument.
        num_tasks: number of tasks mapped per pool.
        type: ``"default"`` selects ``concurrent.futures.ThreadPoolExecutor``;
            any other value selects ``CvarThreadPool``. The name shadows the
            builtin but is kept for backward compatibility.
        duration: argument passed to every ``fn`` call (previously a
            hard-coded 0.1).

    Returns:
        Dict mapping each worker count to the elapsed seconds.
    """
    times = {}
    for num_workers in num_workers_list:
        if type == "default":
            pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
        else:
            pool = CvarThreadPool(num_workers=num_workers)
        start_time = time.perf_counter()
        try:
            pool.map(fn, [duration] * num_tasks)
        finally:
            # Always join the workers, even if map raises, so no threads leak
            # into later iterations or the caller.
            pool.shutdown()
        end_time = time.perf_counter()
        # Read the count once so the message reports the value that was tested.
        active = threading.active_count()
        assert active == 1, f"Only one active thread (the main thread) must remain, got {active}"
        times[num_workers] = end_time - start_time
    return times
if __name__ == "__main__":
    # Recorded results (elapsed seconds keyed by worker count; 256 tasks,
    # each called with 0.1):
    # Default IO: {2: 13.275235374923795, 4: 6.655313625000417, 8: 3.3173186670755967, 16: 1.6627995839808136, 32: 0.8406297919573262, 64: 0.42237862502224743, 128: 0.21964483300689608, 256: 0.12271529098507017, 512: 0.11788716702722013}
    # Default CPU: {2: 12.96525054203812, 4: 6.929109917022288, 8: 4.171173915965483, 16: 3.3885942089837044, 32: 3.4872594580519944, 64: 3.377452834043652, 128: 3.634125957963988, 256: 3.5097480000695214, 512: 3.2952734159771353}
    # IO Bound: {2: 13.285975374979898, 4: 6.644823250011541, 8: 3.3390898749930784, 16: 1.6570523750269786, 32: 0.8338112080236897, 64: 0.4214848750270903, 128: 0.22665741597302258, 256: 0.12575916689820588, 512: 0.14143420895561576}
    # CPU Bound: {2: 12.980181750026532, 4: 6.9501094170846045, 8: 4.130237541976385, 16: 3.182236749911681, 32: 2.6954647090751678, 64: 2.2214373329188675, 128: 1.8402691250666976, 256: 1.6856321659870446, 512: 1.665079541038721}
    worker_counts = [2**exponent for exponent in range(1, 10)]  # 2 .. 512
    task_count = 2 ** 8

    # Baseline: the standard-library executor.
    default_io_times = profile_pool(worker_counts, io_bound, task_count, "default")
    default_cpu_times = profile_pool(worker_counts, cpu_bound, task_count, "default")

    # The condition-variable pool defined above.
    io_times = profile_pool(worker_counts, io_bound, task_count)
    cpu_times = profile_pool(worker_counts, cpu_bound, task_count)

    print(f"Default IO: {default_io_times}")
    print(f"Default CPU: {default_cpu_times}")
    print(f"IO Bound: {io_times}")
    print(f"CPU Bound: {cpu_times}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment