import typing
from typing import Any
import string
import datasets
from datasets import Dataset as HFDataset
import torch
from torch import Tensor
from torch.utils.data import (
    BatchSampler, DataLoader, Dataset, RandomSampler
)
# Don't cache dataset intermediate transforms.
datasets.set_caching_enabled(False)

# Previously 'columnized' (i.e. keys correspond to columns).
toy_dataset = {
    "name": ["John", "Jacob", "Jingleheimer", "Schmidt"],
    "is_my_name_too": [True, True, True, True]
}
# In practice, you'd probably write a data loading script instead of using `from_dict()`.
# Aside: I wish there were an equivalent to tf.data.Dataset.from_generator().
dataset: HFDataset = HFDataset.from_dict(toy_dataset)
# Simple vocab for char->idx mapping.
lower_alpha_vocab: dict[str, int] = {"<PAD>": 0}
lower_alpha_vocab.update(
    zip(string.ascii_lowercase, range(1, 26 + 1))
)
print(lower_alpha_vocab)  # -> {'<PAD>': 0, 'a': 1, 'b': 2, ..., 'y': 25, 'z': 26}
def char_tokenize(strings: list[str]):
    """Tokenize characters."""
    return [list(s) for s in strings]
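# Quick sanity check (illustrative input, not part of the original gist):
assert char_tokenize(["ab", "c"]) == [["a", "b"], ["c"]]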
def numericalize_and_pad(seqs_batch: list[list[str]], vocab: dict[str, int]):
    """Numericalize sequences in batch and pad result, with pad_idx = `vocab["<PAD>"]`."""
    max_len = max(map(len, seqs_batch))
    pad_idx = vocab["<PAD>"]  # slightly icky
    padded = []
    for seq in seqs_batch:
        # Numericalize.
        idxs = [vocab[tok] for tok in seq]
        idxs.extend(pad_idx for _ in range(max_len - len(idxs)))
        padded.append(idxs)
    return padded
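# Quick sanity check of the padding behavior (illustrative input, not part of
# the original gist): "c" gets padded to max_len=2 with pad_idx=0.
assert numericalize_and_pad([["a", "b"], ["c"]], lower_alpha_vocab) == [[1, 2], [3, 0]]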
dataset: HFDataset = (
    dataset
    # Data cleaning
    .map(
        lambda batch: {"name": [s.lower() for s in batch["name"]]},
        batched=True
    )
    # Tokenize
    .map(
        lambda batch: {"tokens": char_tokenize(batch["name"])},
        batched=True
    )
)
print(dataset[0])  # -> {'is_my_name_too': True, 'name': 'john', 'tokens': ['j', 'o', 'h', 'n']}
TensorDict = dict[str, Tensor]

def collate(batch: dict[str, list[Any]], vocab: dict[str, int]) -> TensorDict:
    return {
        "token_idxs": torch.tensor(
            numericalize_and_pad(batch["tokens"], vocab),
            dtype=torch.int64
        ),
        "is_my_name_too": torch.tensor(batch["is_my_name_too"])
    }
dataset: HFDataset = dataset.with_transform(
    # Batch collation now happens *outside* of DataLoader.
    lambda batch: collate(batch, vocab=lower_alpha_vocab)
)
# HF dataset supports fancy indexing!
# In:
dataset[[1, 2]]
# Out:
# {'token_idxs': tensor([[10,  1,  3, 15,  2,  0,  0,  0,  0,  0,  0,  0],
#                        [10,  9, 14,  7, 12,  5,  8,  5,  9, 13,  5, 18]]),
#  'is_my_name_too': tensor([True, True])}
# Using HF Dataset with PyTorch DataLoader
sampler = BatchSampler(
    sampler=RandomSampler(dataset),
    batch_size=2,
    drop_last=False
)
loader = DataLoader(
    # Tell the type checker to treat the 🤗 Dataset as a torch.utils.data.Dataset.
    typing.cast(Dataset, dataset),
    sampler=sampler,
    # Important! Ensures that batches are fetched from the HF Dataset by fancy indexing,
    # e.g. batch = dataset[[3, 1, 2, 4]] instead of batch = [dataset[i] for i in [3, 1, 2, 4]].
    # This lets us perform collation using `with_transform()`.
    batch_size=None,
    collate_fn=None
)
# Grab a first mini-batch
# In:
next(iter(loader))
# Out:
# {'token_idxs': tensor([[19,  3,  8, 13,  9,  4, 20],
#                        [10, 15,  8, 14,  0,  0,  0]]),
#  'is_my_name_too': tensor([True, True])}
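# A minimal downstream sketch (hypothetical, not part of the original gist):
# one way the collated batches could feed a tiny classifier. `TinyClassifier`
# and its sizes are made up for illustration.
class TinyClassifier(torch.nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int = 8):
        super().__init__()
        # padding_idx=0 keeps the <PAD> embedding fixed at zero.
        self.embed = torch.nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.fc = torch.nn.Linear(embed_dim, 1)

    def forward(self, token_idxs: Tensor) -> Tensor:
        # Mean-pool character embeddings over the sequence dimension, then classify.
        pooled = self.embed(token_idxs).mean(dim=1)
        return self.fc(pooled).squeeze(-1)

model = TinyClassifier(vocab_size=len(lower_alpha_vocab))
for batch in loader:
    logits = model(batch["token_idxs"])
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        logits, batch["is_my_name_too"].float()
    )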