Skip to content

Instantly share code, notes, and snippets.

@vincentqb
Last active March 16, 2020 14:17
Show Gist options
  • Select an option

  • Save vincentqb/f9ae09fd55c1e493cadec0067851bedf to your computer and use it in GitHub Desktop.

Select an option

Save vincentqb/f9ae09fd55c1e493cadec0067851bedf to your computer and use it in GitHub Desktop.
Strided Buffer
from itertools import repeat
class StridedBuffer:
    """Regroup a flat stream of items into strided batches.

    Item ``i`` of the input stream is assigned to lane ``i % stride``.
    Batches are emitted one lane at a time, cycling through the lanes in
    order. With ``stride=2, length=4``, the stream ``0, 1, 2, ...`` yields
    ``[0, 2, 4, 6]``, ``[1, 3, 5, 7]``, ``[8, 10, 12, 14]``, ...

    Items are pulled from the input lazily, only as many as are needed to
    fill the next batch. When the input is exhausted, iteration stops and
    any partially filled lanes are discarded.

    Args:
        generator: An iterable (or iterator) of items.
        stride: Number of interleaved lanes; must be >= 1.
        length: Number of items per emitted batch; must be >= 1.

    Raises:
        ValueError: If ``stride`` or ``length`` is less than 1.
    """

    def __init__(self, generator, stride, length):
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        if length < 1:
            # length 0 would otherwise yield empty lists forever.
            raise ValueError(f"length must be >= 1, got {length}")
        # iter() accepts plain iterables and returns iterators unchanged.
        self._generator = iter(generator)
        self._stride = stride
        self._length = length
        # One list per lane; self._buffer[0] is always the lane emitted next.
        self._buffer = [[] for _ in repeat(None, stride)]
        # Index into self._buffer of the lane that receives the next item.
        self._mod = 0

    def __iter__(self):
        return self

    def __next__(self):
        while len(self._buffer[0]) < self._length:
            # StopIteration from the input propagates and ends iteration;
            # partially filled lanes are dropped.
            item = next(self._generator)
            self._buffer[self._mod].append(item)
            self._mod = (self._mod + 1) % self._stride
        batch = self._buffer.pop(0)
        # The emitted lane restarts empty at the back of the rotation.
        self._buffer.append([])
        # Popping the front lane shifted every remaining lane down by one
        # index, so shift the write position too; otherwise later items are
        # routed into the wrong lane.
        self._mod = (self._mod - 1) % self._stride
        return batch
# Demo: interleave the integers 0..9 into two lanes and emit 4-item batches.
dataset = iter(range(10))
for batch in StridedBuffer(dataset, 2, 4):
    # Each iteration waits until a full batch is ready; when the input runs
    # out, any partial batch is dropped.
    print(batch)
    # In a training loop the batch would be fed to a model here, e.g.
    # outputs = model(batch) with a PyTorch nn.Module.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment