Simsso/aef9fe3842952659065f91f5c0bd3827, created June 3, 2020 15:50
TensorFlow code for looking up one value per matrix row, using the column indices stored in a vector.
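As a point of reference, the intended semantics, return[i] = mat[i][vec[i]], can be sketched in NumPy with advanced indexing. This is a minimal sketch that assumes NumPy is available. The helper name `vec_lookup_in_mat_np` is hypothetical and is not part of the gist. A version like this can serve as a test oracle for the TensorFlow function.

```python
import numpy as np


def vec_lookup_in_mat_np(mat: np.ndarray, vec: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference: out[i] = mat[i][vec[i]]."""
    mat = np.asarray(mat)
    vec = np.asarray(vec)
    if mat.shape[0] != vec.shape[0]:
        raise ValueError("Vec shape must match number of matrix rows")
    # Pair each row index i with its column index vec[i]. Advanced indexing
    # then picks one entry (or one trailing sub-array) per row.
    rows = np.arange(mat.shape[0])
    return mat[rows, vec]


mat = np.array([[1, 2, 3], [4, 5, 6]])
print(vec_lookup_in_mat_np(mat, np.array([0, 2])))  # [1 6]
```

With a 3-D input of shape [m, n, k], the same indexing returns shape [m, k], which matches the `[m, ...]` return shape described in the docstring.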
```python
import tensorflow as tf


def vec_lookup_in_mat(mat: tf.Tensor, vec: tf.Tensor) -> tf.Tensor:
    """
    Performs a lookup of values in the matrix rows using the indices stored in the vector.
    For example, given the matrix [[1, 2, 3], [4, 5, 6]] and the vector [0, 2], the function returns a vector
    consisting of the entry with index 0 in the first row and the entry with index 2 in the second row, i.e. [1, 6].

    :param mat: Tensor of shape [m, n, ...], values will be looked up from it; dtype any
    :param vec: Tensor of shape [m], indices for each row of the matrix; dtype int32 or int64
    :return: Tensor of shape [m, ...], created as return[i] = mat[i][vec[i]]
    """
    mat = tf.convert_to_tensor(mat)
    vec = tf.convert_to_tensor(vec)
    # Static check; dimensions unknown at graph-construction time count as compatible.
    if not mat.shape[:1].is_compatible_with(vec.shape[:1]):
        raise ValueError("Vec shape must match number of matrix rows")
    # Row indices 0..m-1, taken from the runtime shape and in the same dtype as vec,
    # so tf.stack does not fail on an int32/int64 mismatch.
    rows = tf.range(tf.shape(vec, out_type=vec.dtype)[0], dtype=vec.dtype)
    # Each row of lookup_indices is a pair [i, vec[i]]; gather_nd reads mat[i][vec[i]].
    lookup_indices = tf.stack((rows, vec), axis=1)
    out = tf.gather_nd(mat, lookup_indices)
    return out
```