Skip to content

Instantly share code, notes, and snippets.

@nicolasdespres
Created March 8, 2017 08:11
Show Gist options
  • Select an option

  • Save nicolasdespres/81689421f56b86a315a81f19d301508a to your computer and use it in GitHub Desktop.

Select an option

Save nicolasdespres/81689421f56b86a315a81f19d301508a to your computer and use it in GitHub Desktop.

Revisions

  1. nicolasdespres created this gist Mar 8, 2017.
    111 changes: 111 additions & 0 deletions iter_shuffle_batch_window.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,111 @@
    class iter_shuffle_batch_window(Iterator):
    """Iterate a window over an in-memory sequence of data.
    Args:
    `window_size`: The size of the window (must be smaller than the data)
    `window_alignment`: How to align the window around the point: "left",
    "right", or "center".
    `shifts`: A list of indices to shifts the windows from. By default
    it is `[0]` but you can set it `[0, 1]` generate batch of
    inputs and targets windows.
    `packer`: A function to pack each batch of windows into a container.
    By default it is `list` but you can set it to `np.stack`
    to get a numpy array.
    """

    def __init__(self, data,
    batch_size=None,
    window_size=None,
    shuffle=True,
    allow_smaller_final_batch=False,
    num_cycles=1,
    window_alignment="left",
    shifts=None,
    packer=list):
    self.data = data
    if not isinstance(window_size, int):
    raise TypeError("window_size must be int, not {}"
    .format(type(window_size).__name__))
    if window_size <= 0:
    raise ValueError("window_size must be positive")
    self._window_size = window_size
    self._window_alignment = window_alignment
    self._shifts = [0] if shifts is None else shifts
    assert all(i >=0 for i in self._shifts), \
    "shifts value must be all positive or null"
    self._range = window_range(len(self.data),
    size=self._window_size + max(self._shifts),
    alignment=self._window_alignment)
    self._take = partial(window_at,
    size=self._window_size,
    alignment=self._window_alignment)

    self._pack = packer

    def reset(self):
    self._it = iter_shuffle_batch_range(
    self._range,
    batch_size=batch_size,
    shuffle=shuffle,
    allow_smaller_final_batch=allow_smaller_final_batch,
    num_cycles=num_cycles)
    self.reset = types.MethodType(reset, self)
    self.reset()

    def __len__(self):
    return len(self._it)

    @property
    def batch_size(self):
    """The size of each batch."""
    return self._it.batch_size

    @property
    def window_size(self):
    return self._window_size

    @property
    def window_alignment(self):
    return self._window_alignment

    @property
    def shifts(self):
    return self._shifts

    @property
    def allow_smaller_final_batch(self):
    return self._it.allow_smaller_final_batch

    @property
    def size(self):
    return self._it.size

    @property
    def num_cycles(self):
    return self._it.num_cycles

    @property
    def shuffle(self):
    return self._it.shuffle

    @property
    def steps_per_epoch(self):
    return self._it.steps_per_epoch

    @property
    def epoch(self):
    return self._it.epoch

    @property
    def step(self):
    return self._it.step

    def __next__(self):
    # Take a batch of indices pointing to the beginning of a window.
    # For each index we get a slice of our data starting from it.
    batch = next(self._it)
    windows = []
    for s in self._shifts:
    windows.append(self._pack(
    [self._take(self.data, i+s) for i in batch]))
    return windows