
@Simsso
Created June 3, 2020 15:50
TensorFlow code for performing a lookup of values in matrix rows using the indices stored in a vector.
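To make the intended semantics concrete, here is a small NumPy reference sketch of `return[i] = mat[i][vec[i]]`. The helper name `vec_lookup_in_mat_np` is hypothetical and not part of the gist. It builds the same (row, column) index pairs as the TensorFlow function and then uses NumPy advanced indexing in place of `tf.gather_nd`.

```python
import numpy as np


def vec_lookup_in_mat_np(mat: np.ndarray, vec: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference: returns out with out[i] = mat[i][vec[i]]."""
    m = mat.shape[0]
    if vec.shape[0] != m:
        raise ValueError("Vec shape must match number of matrix rows")
    # One (row, column) pair per row, shape [m, 2]: [[0, vec[0]], [1, vec[1]], ...]
    lookup_indices = np.stack((np.arange(m), vec), axis=1)
    # Indexing with the row column and the column column picks mat[i, vec[i]]
    # for every i and keeps any trailing dimensions of mat.
    return mat[lookup_indices[:, 0], lookup_indices[:, 1]]


mat = np.array([[1, 2, 3], [4, 5, 6]])
vec = np.array([0, 2])
print(vec_lookup_in_mat_np(mat, vec))  # [1 6]
```

For a matrix with extra trailing dimensions, e.g. shape [m, n, k], the result has shape [m, k], matching the `[m, ...]` return shape documented in the gist.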
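```python
import tensorflow as tf


def vec_lookup_in_mat(mat: tf.Tensor, vec: tf.Tensor) -> tf.Tensor:
    """
    Performs a lookup of values in the matrix rows using the indices stored in the vector.
    For example, given the matrix [[1, 2, 3], [4, 5, 6]] and the vector [0, 2], the function returns a vector
    consisting of the entry with index 0 in the first row and index 2 in the second row, i.e. [1, 6].
    :param mat: Tensor of shape [m, n, ...], values will be looked up from it; dtype any
    :param vec: Tensor of shape [m], indices for each row of the matrix; dtype int32
    :return: Tensor of shape [m, ...], created as return[i] = mat[i][vec[i]]
    """
    m = mat.get_shape()[0]
    vec_len = vec.get_shape()[0]
    # Only compare sizes that are known statically; unknown (None) dimensions are not checked here.
    if m is not None and vec_len is not None and vec_len != m:
        raise ValueError("Vec shape must match number of matrix rows")
    # Build one (row, column) index pair per row: [[0, vec[0]], [1, vec[1]], ...], shape [m, 2].
    # tf.shape (int32) is used so this also works when the row count is only known at runtime.
    row_indices = tf.range(tf.shape(mat)[0])
    lookup_indices = tf.stack((row_indices, vec), axis=1)
    # gather_nd reads mat[i, vec[i]] for every index pair, keeping any trailing dimensions.
    out = tf.gather_nd(mat, lookup_indices)
    return out
```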
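A short usage sketch, assuming TensorFlow 2.x with eager execution (needed for `.numpy()`). It reproduces the docstring example with the explicit `tf.stack`/`tf.gather_nd` construction and compares it against `tf.gather(..., batch_dims=1)`, a different built-in route to the same result that treats axis 0 as a batch axis, so each row `i` is indexed by its own `vec[i]`.

```python
import tensorflow as tf

mat = tf.constant([[1, 2, 3], [4, 5, 6]])
vec = tf.constant([0, 2], dtype=tf.int32)

# Explicit (row, column) pairs, as built inside vec_lookup_in_mat.
lookup_indices = tf.stack((tf.range(tf.shape(mat)[0]), vec), axis=1)
via_gather_nd = tf.gather_nd(mat, lookup_indices)

# Alternative: with batch_dims=1, axis defaults to 1, so row i is indexed by vec[i].
via_batch_gather = tf.gather(mat, vec, batch_dims=1)

print(via_gather_nd.numpy())     # [1 6]
print(via_batch_gather.numpy())  # [1 6]
```

The `gather_nd` version spells out the index pairs, which makes the lookup easy to inspect; the `batch_dims` version avoids building the `[m, 2]` index tensor by hand.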