Skip to content

Instantly share code, notes, and snippets.

@ml-edu
Forked from lantiga/README.md
Created April 27, 2023 04:51
Show Gist options
  • Select an option

  • Save ml-edu/d272d9518c8d898a4910bd805288aabe to your computer and use it in GitHub Desktop.

Select an option

Save ml-edu/d272d9518c8d898a4910bd805288aabe to your computer and use it in GitHub Desktop.

Revisions

  1. @lantiga lantiga created this gist Feb 6, 2018.
    27 changes: 27 additions & 0 deletions README.md
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,27 @@

    # Indexed convolutions

    A convolution operator over a 1D tensor (BxCxL), where a list of neighbors for each element is provided through a indices tensor (LxK), where K is the size of the convolution kernel. Each row of indices specifies the indices of the K neighbors of the corresponding element in the input. A -1 is handled like for zero padding.

    Note that the neighbors specified in indices are not relative, but rather absolute. They have to be specified for each of the elements of the output.

    A use case is for convolutions over non-square lattices, such as images on hexagonal lattices coming from Cherenkov telescopes (http://www.isdc.unige.ch/%7Elyard/FirstLight/FirstLight_slowHD.mov).

    Example:

    ```
    import torch
    # a 1D input of 5 elems
    input = torch.randn(1,1,5)
    # this specifies the indices of neighbors for
    # each elem of the input (a 3 elem kernel here)
    # A -1 corresponds to zero-padding
    indices = torch.ones(5,3).type(torch.LongTensor)
    weight = torch.randn(1,1,3)
    bias = torch.randn(1)
    output = torch.nn.functional.indexed_conv(input, indices, weight, bias)
    ```
    67 changes: 67 additions & 0 deletions indexed_conv.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,67 @@
    import torch
    from torch.autograd import Variable

    def prepare_mask(indices):

    padded = indices == -1
    indices[padded] = 0

    mask = torch.FloatTensor([1,0])
    mask = mask[..., padded.t().long()]

    return indices, mask


    def indexed_conv(input, weight, bias, indices, mask):

    nbatch = input.shape[0]
    output_width = indices.shape[0]
    out_chans, in_chans, ksize = weight.shape

    if isinstance(input, Variable):
    mask = Variable(mask)

    col = input[..., indices.t()] * mask
    col = col.view(nbatch, -1, output_width)

    weight_col = weight.view(out_chans, -1)

    out = torch.matmul(weight_col, col) + bias

    #print(col)
    #print(weight_col)

    return out


    if __name__ == '__main__':

    # input = torch.randn(1,2,5)
    # weight = torch.randn(1,2,3)
    # bias = torch.randn(1)
    # indices = (5 * torch.rand(4,3)).long()

    input = torch.ones(1,2,5)
    weight = torch.ones(1,2,3)
    bias = torch.zeros(1)
    indices = (5 * torch.rand(4,3)).long()
    indices[0,0] = -1

    indices, mask = prepare_mask(indices)

    print(input)
    print(indices)

    out = indexed_conv(input, weight, bias, indices, mask)

    input = Variable(input, requires_grad=True)
    weight = Variable(weight)
    bias = Variable(bias)

    out = indexed_conv(input, weight, bias, indices, mask)

    print(out)

    out.sum().backward()
    print(input.grad)