Skip to content

Instantly share code, notes, and snippets.

@jchacks
Created May 28, 2021 09:54
Show Gist options
  • Select an option

  • Save jchacks/c80acabe3c039d2f3f83ba6f38b19df6 to your computer and use it in GitHub Desktop.

Select an option

Save jchacks/c80acabe3c039d2f3f83ba6f38b19df6 to your computer and use it in GitHub Desktop.
Machine Learning flat batch creation benchmark script.
import numpy as np
import time
import timeit
# Number of samples stacked into every output batch.
batch_size = 5000
# Number of parallel datasets; each opt* function returns one batch per dataset.
datasets = 4
# One sample per dataset. List multiplication repeats the reference, so every
# entry is the same length-100 array of ones.
data = [np.ones(100)] * datasets
def opt1():
    """Collect whole samples, transpose with ``zip`` and stack via ``map``.

    Appends the module-level ``data`` list ``batch_size`` times, regroups the
    collected samples per dataset with ``zip(*...)`` and stacks every group
    with ``np.stack``.

    Returns:
        list: one stacked array per entry of ``data``; each array has
        ``batch_size`` rows.
    """
    samples = []
    # Explicit per-sample append: this accumulation is part of the timed work.
    for _ in range(batch_size):
        samples.append(data)
    return list(map(np.stack, zip(*samples)))
def opt2():
    """Append into one list per dataset, then stack the lists via ``map``.

    Creates ``datasets`` empty lists, appends ``data[j]`` to list ``j`` for
    each of the ``batch_size`` iterations and stacks every list with
    ``np.stack``.

    Returns:
        list: ``datasets`` stacked arrays, each with ``batch_size`` rows.
    """
    columns = [[] for _ in range(datasets)]
    # Explicit nested loops: the per-element append is part of the timed work.
    for _ in range(batch_size):
        for j in range(datasets):
            columns[j].append(data[j])
    return list(map(np.stack, columns))
def opt3():
    """Collect whole samples, transpose with ``zip`` and stack in a loop.

    Appends the module-level ``data`` list ``batch_size`` times and regroups
    the samples per dataset with ``zip(*...)``. Unlike the ``map`` based
    variant, each of the first ``datasets`` groups is then replaced in place
    by its ``np.stack`` result inside an explicit ``for`` loop.

    Returns:
        list: the regrouped list, whose first ``datasets`` entries are
        stacked arrays with ``batch_size`` rows.
    """
    samples = []
    # Explicit per-sample append: this accumulation is part of the timed work.
    for _ in range(batch_size):
        samples.append(data)
    batches = list(zip(*samples))
    for i in range(datasets):
        batches[i] = np.stack(batches[i])
    return batches
def opt4():
    """Append into one list per dataset, then stack the lists in a loop.

    Creates ``datasets`` empty lists, appends ``data[j]`` to list ``j`` for
    each of the ``batch_size`` iterations and replaces every list in place by
    its ``np.stack`` result inside an explicit ``for`` loop.

    Returns:
        list: ``datasets`` stacked arrays, each with ``batch_size`` rows.
    """
    batches = [[] for _ in range(datasets)]
    # Explicit nested loops: the per-element append is part of the timed work.
    for _ in range(batch_size):
        for j in range(datasets):
            batches[j].append(data[j])
    for i in range(datasets):
        batches[i] = np.stack(batches[i])
    return batches
def opt5():
    """Pre-allocate zero-filled batches and write every sample row by row.

    Allocates one ``np.zeros`` buffer per dataset on every call and then
    assigns ``data[j]`` into row ``i`` of buffer ``j``.

    The buffer shape is taken from the sample itself rather than a
    hard-coded length, so it stays in sync with ``data``; for the file's
    length-100 samples this is the same ``(batch_size, 100)`` shape as before.

    Returns:
        list: ``datasets`` arrays of shape ``(batch_size,) + sample shape``.
    """
    batches = [
        np.zeros((batch_size,) + data[j].shape) for j in range(datasets)
    ]
    # Explicit nested loops: the row-wise assignment is part of the timed work.
    for i in range(batch_size):
        for j in range(datasets):
            batches[j][i] = data[j]
    return batches
# Zero-filled templates allocated once at import time; opt6 copies them on
# every call instead of calling np.zeros again. The shape is derived from the
# samples so it cannot drift from ``data`` (identical to (batch_size, 100)
# for the file's length-100 samples).
_batches = [np.zeros((batch_size,) + data[j].shape) for j in range(datasets)]


def opt6():
    """Copy the pre-built zero templates and write every sample row by row.

    Copies each array in the module-level ``_batches`` list, then assigns
    ``data[j]`` into row ``i`` of copy ``j``. The templates themselves are
    never modified.

    Returns:
        list: one filled copy per template in ``_batches``.
    """
    batches = [template.copy() for template in _batches]
    # Explicit nested loops: the row-wise assignment is part of the timed work.
    for i in range(batch_size):
        for j in range(datasets):
            batches[j][i] = data[j]
    return batches
# Sanity check: every variant must return one (batch_size, sample length)
# array per dataset. An explicit check replaces bare expression statements,
# whose values are silently discarded when the file runs as a script, and it
# covers all six variants.
_expected_shape = (batch_size,) + data[0].shape
_option_names = ('opt1', 'opt2', 'opt3', 'opt4', 'opt5', 'opt6')
for _name in _option_names:
    _shapes = [batch.shape for batch in globals()[_name]()]
    if _shapes != [_expected_shape] * datasets:
        raise RuntimeError(
            '{} returned unexpected batch shapes: {}'.format(_name, _shapes)
        )

# Print the total time in seconds for 100 calls of each variant, one line per
# variant in the order opt1..opt6. The statement is passed as a string so the
# timed code is the same 'optN()' call as before.
for _name in _option_names:
    print(timeit.timeit('{}()'.format(_name), globals=globals(), number=100))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment