# https://stackoverflow.com/questions/53403306/how-to-batch-convert-sentence-lengths-to-masks-in-pytorch
def len_to_mask(lens: np.ndarray, seq_len: Optional[int] = None) -> np.ndarray:
    """Convert a batch of sequence lengths into a boolean padding mask.

    Position ``j`` of row ``i`` is True iff ``j < lens[i]``.

    Example::

        lens = [3, 5, 4]
        mask = [[1, 1, 1, 0, 0],
                [1, 1, 1, 1, 1],
                [1, 1, 1, 1, 0]]   # returned as a bool array

    Args:
        lens: Integer lengths. Anything accepted by ``np.asarray`` works
            (e.g. a list). For an N-D input, the mask gets one extra
            trailing axis.
        seq_len: Width of the mask's last axis. Defaults to ``max(lens)``,
            or 0 when ``lens`` is empty. A value smaller than some length
            truncates those rows.

    Returns:
        Boolean array of shape ``lens.shape + (seq_len,)``.
    """
    lens = np.asarray(lens)
    if seq_len is None:
        # Vectorized max; an empty batch yields a zero-width mask instead of
        # raising ValueError like the builtin max() would.
        seq_len = int(lens.max()) if lens.size else 0
    # Broadcast positions (seq_len,) against lengths (..., 1).
    return np.arange(seq_len) < lens[..., None]