Last active
December 6, 2023 15:14
-
-
Save MaugrimEP/0f608a9cd2e6b30e0291bd6b591536bf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import pickle | |
| import lmdb | |
| from torch.utils.data import Dataset, DataLoader | |
| import os | |
| from tqdm import tqdm | |
def _dumps_pickle(obj) -> bytes:
    """Serialize *obj* to a bytes payload using ``pickle.dumps`` defaults."""
    payload = pickle.dumps(obj)
    return payload
def _identity_collate(batch):
    """Return the list of samples unchanged instead of stacking them.

    Defined at module level (rather than as a lambda) so it can be pickled
    when DataLoader workers are started with the ``spawn`` method.
    """
    return batch


def lmdb2lmdbpickle(
    src_dataset: Dataset,
    dest_path: str,
    name: str = "train",
    write_frequency: int = 500,
    num_workers: int = 16,
    map_size: int = int(8e9),
    pickle_func=_dumps_pickle,
) -> None:
    """Dump every sample of *src_dataset* into an LMDB database.

    Samples are read one at a time through a ``DataLoader`` and stored under
    the key ``str(index).encode()`` with the value ``pickle_func(sample)``.
    Two extra records are written at the end: ``b"__keys__"`` (the serialized
    list of string keys) and ``b"__len__"`` (the serialized record count).

    Args:
        src_dataset: Dataset to export; must support ``len()``.
        dest_path: Directory in which the database is created. It is created
            if it does not exist yet.
        name: Name of the database inside *dest_path*.
        write_frequency: Number of samples written per transaction commit.
        num_workers: Number of ``DataLoader`` worker processes.
        map_size: Maximum size in bytes the LMDB map may grow to.
        pickle_func: Callable turning one sample into ``bytes``.

    Raises:
        ValueError: If *write_frequency* is not a positive integer.
    """
    if write_frequency <= 0:
        raise ValueError(
            f"write_frequency must be a positive integer, got {write_frequency!r}"
        )

    data_loader = DataLoader(
        src_dataset, collate_fn=_identity_collate, num_workers=num_workers
    )
    total = len(data_loader)
    print(f"{len(src_dataset)=}", f"{len(data_loader)=}")

    # lmdb.open does not create missing parent directories.
    if dest_path:
        os.makedirs(dest_path, exist_ok=True)
    lmdb_path = os.path.join(dest_path, name)
    # An existing directory is used in subdir mode; otherwise a single file
    # is created at lmdb_path.
    isdir = os.path.isdir(lmdb_path)
    print(f"Generate LMDB to {lmdb_path}")

    with lmdb.open(
        lmdb_path,
        subdir=isdir,
        map_size=map_size,
        readonly=False,
        meminit=False,
        map_async=True,
    ) as db:
        written = 0
        txn = db.begin(write=True)
        try:
            for idx, batch in enumerate(tqdm(data_loader)):
                # Default batch size is 1, so each batch holds one sample.
                txn.put(f"{idx}".encode(), pickle_func(batch[0]))
                written = idx + 1
                if written % write_frequency == 0:
                    print(f"writing: [{written}/{total}]")
                    txn.commit()
                    txn = db.begin(write=True)
            # Commit whatever is left after the last full chunk.
            txn.commit()
        except BaseException:
            # Do not leave a dangling write transaction behind on failure.
            txn.abort()
            raise

        # Index the records that were actually written.
        keys = [str(k) for k in range(written)]
        with db.begin(write=True) as txn:
            txn.put(b"__keys__", pickle_func(keys))
            txn.put(b"__len__", pickle_func(len(keys)))

        print("Flushing database ....")
        db.sync()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment