Created
May 28, 2021 09:54
-
-
Save jchacks/c80acabe3c039d2f3f83ba6f38b19df6 to your computer and use it in GitHub Desktop.
Machine Learning flat batch creation benchmark script.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import time | |
| import timeit | |
# Benchmark parameters: samples per batch, and arrays carried by each sample.
batch_size, datasets = 5000, 4
# One sample: a list of `datasets` references to the SAME length-100 vector of
# ones (list repetition aliases the array; it does not copy it).
data = datasets * [np.ones(100)]
def opt1():
    """Build batches by collecting whole samples, then transposing with ``zip``.

    Appends the module-level ``data`` sample ``batch_size`` times, regroups
    the entries per dataset with ``zip(*...)`` and stacks each group via
    ``map(np.stack, ...)``.

    Returns:
        A list with one stacked array per entry of ``data``; each array has
        ``batch_size`` rows along a new leading axis.
    """
    samples = []
    for _ in range(batch_size):
        samples.append(data)
    # zip(*samples) turns "list of samples" into "one tuple per dataset".
    return list(map(np.stack, zip(*samples)))
def opt2():
    """Build batches by appending each sample entry to a per-dataset list.

    Keeps one Python list per dataset, fills them entry by entry over
    ``batch_size`` iterations, then stacks each list via
    ``map(np.stack, ...)``.

    Returns:
        A list of ``datasets`` stacked arrays, each with ``batch_size`` rows
        along a new leading axis.
    """
    columns = [[] for _ in range(datasets)]
    for _ in range(batch_size):
        for k in range(datasets):
            columns[k].append(data[k])
    return list(map(np.stack, columns))
def opt3():
    """Like ``opt1`` but stacks with an explicit loop instead of ``map``.

    Appends the module-level ``data`` sample ``batch_size`` times, regroups
    the entries per dataset with ``zip(*...)`` and then replaces each group
    in place with its ``np.stack`` result.

    Returns:
        A list of stacked arrays, each with ``batch_size`` rows along a new
        leading axis.
    """
    samples = []
    for _ in range(batch_size):
        samples.append(data)
    # Materialise the transposed view so the groups can be overwritten by index.
    columns = list(zip(*samples))
    for k in range(datasets):
        columns[k] = np.stack(columns[k])
    return columns
def opt4():
    """Like ``opt2`` but stacks with an explicit loop instead of ``map``.

    Keeps one Python list per dataset, fills them entry by entry over
    ``batch_size`` iterations, then replaces each list in place with its
    ``np.stack`` result.

    Returns:
        A list of ``datasets`` stacked arrays, each with ``batch_size`` rows
        along a new leading axis.
    """
    columns = [[] for _ in range(datasets)]
    for _ in range(batch_size):
        for k in range(datasets):
            columns[k].append(data[k])
    for k in range(datasets):
        columns[k] = np.stack(columns[k])
    return columns
def opt5():
    """Build batches by writing samples into freshly preallocated arrays.

    Allocates one zero-filled output array per dataset and assigns each
    sample entry to its row, avoiding ``np.stack`` entirely.

    The per-row shape and dtype are taken from ``data[j]`` rather than being
    hard-coded, so this option keeps working (and stays comparable with
    ``opt1``..``opt4``) if the sample arrays in ``data`` are changed.

    Returns:
        A list of ``datasets`` arrays, each with ``batch_size`` rows along
        the leading axis.
    """
    batches = [
        np.zeros((batch_size, *data[j].shape), dtype=data[j].dtype)
        for j in range(datasets)
    ]
    for i in range(batch_size):
        for j in range(datasets):
            batches[j][i] = data[j]
    return batches
# Zero-filled templates, allocated once at import time and copied by opt6().
# Row shape and dtype follow ``data[j]`` instead of a hard-coded width so the
# templates stay valid if the sample arrays in ``data`` are changed.
_batches = [
    np.zeros((batch_size, *data[j].shape), dtype=data[j].dtype)
    for j in range(datasets)
]


def opt6():
    """Build batches by writing samples into copies of preallocated templates.

    Same row-by-row fill as ``opt5``, but the output arrays are obtained by
    copying the module-level ``_batches`` templates instead of calling
    ``np.zeros`` on every invocation.

    Returns:
        A list of ``datasets`` arrays, each with ``batch_size`` rows along
        the leading axis. The templates themselves are never modified.
    """
    batches = [b.copy() for b in _batches]
    for i in range(batch_size):
        for j in range(datasets):
            batches[j][i] = data[j]
    return batches
# Sanity check before timing: every strategy must reproduce the output of
# opt1(); otherwise the timings below would be comparing different work.
# (The former bare `len(...), ...shape` expressions were evaluated and
# discarded when run as a script, and covered only three of the six options.)
_option_names = ("opt1", "opt2", "opt3", "opt4", "opt5", "opt6")
_expected = opt1()
for _name in _option_names:
    _actual = globals()[_name]()
    _matches = len(_actual) == len(_expected) and all(
        np.array_equal(got, want) for got, want in zip(_actual, _expected)
    )
    if not _matches:
        raise RuntimeError(
            "{}() does not reproduce the output of opt1()".format(_name)
        )

# Time 100 calls of each strategy and print the total per option, one float
# per line, in the order opt1..opt6.
for _name in _option_names:
    print(timeit.timeit("{}()".format(_name), globals=globals(), number=100))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment