Skip to content

Instantly share code, notes, and snippets.

@MaugrimEP
Last active December 6, 2023 15:14
Show Gist options
  • Select an option

  • Save MaugrimEP/0f608a9cd2e6b30e0291bd6b591536bf to your computer and use it in GitHub Desktop.

Select an option

Save MaugrimEP/0f608a9cd2e6b30e0291bd6b591536bf to your computer and use it in GitHub Desktop.
import pickle
import lmdb
from torch.utils.data import Dataset, DataLoader
import os
from tqdm import tqdm
def _dumps_pickle(obj) -> bytes:
    """Serialize ``obj`` into a bytes payload using :func:`pickle.dumps`.

    This is the default serializer used when writing samples to LMDB; the
    pickle protocol is left at the interpreter's default.
    """
    serialized = pickle.dumps(obj)
    return serialized
def _identity_collate(batch):
    """Return the list of samples unchanged (no stacking or conversion).

    Defined at module level instead of as a lambda so that it can be pickled
    when DataLoader worker processes are started with the ``spawn`` method.
    """
    return batch


def lmdb2lmdbpickle(
    src_dataset: Dataset,
    dest_path: str,
    name: str = "train",
    write_frequency: int = 500,
    num_workers: int = 16,
    map_size: int = int(8e9),
    pickle_func=_dumps_pickle,
) -> None:
    """Dump every sample of ``src_dataset`` into an LMDB database.

    Samples are read one at a time through a ``DataLoader`` (default batch
    size of 1, no collation), serialized with ``pickle_func`` and stored under
    the key ``str(index).encode()``. Two extra entries are written at the end:
    ``b"__keys__"`` (serialized list of the string keys) and ``b"__len__"``
    (serialized number of keys).

    Args:
        src_dataset: Dataset to read samples from; must support ``len()``.
        dest_path: Directory that will contain the database. It is created
            if it does not exist yet.
        name: Name of the database inside ``dest_path``.
        write_frequency: A commit is issued whenever the sample index is a
            multiple of this value.
        num_workers: Number of DataLoader worker processes.
        map_size: Maximum size of the LMDB memory map, in bytes.
        pickle_func: Callable turning a sample (and the key list / length)
            into ``bytes``.
    """
    data_loader = DataLoader(
        src_dataset, collate_fn=_identity_collate, num_workers=num_workers
    )
    print(f"{len(src_dataset)=}", f"{len(data_loader)=}")

    # lmdb.open does not create missing parent directories.
    if dest_path:
        os.makedirs(dest_path, exist_ok=True)
    lmdb_path = os.path.join(dest_path, name)
    # Only treat the target as a directory-style database when such a
    # directory already exists; otherwise lmdb_path is used as a file name.
    isdir = os.path.isdir(lmdb_path)

    print(f"Generate LMDB to {lmdb_path}")
    with lmdb.open(
        lmdb_path,
        subdir=isdir,
        map_size=map_size,
        readonly=False,
        meminit=False,
        map_async=True,
    ) as db:
        num_written = 0
        txn = db.begin(write=True)
        for idx, batch in enumerate(tqdm(data_loader)):
            # The loader yields single-element lists (batch size 1).
            data = batch[0]
            txn.put(f"{idx}".encode(), pickle_func(data))
            num_written = idx + 1
            if idx % write_frequency == 0:
                print(f"writing: [{idx+1}/{len(data_loader)}]")
                # Commit periodically and start a fresh write transaction.
                txn.commit()
                txn = db.begin(write=True)

        # finish iterating through the dataset
        txn.commit()

        # Index exactly the records that were written, so the stored key
        # list can never reference a missing entry.
        keys = [str(k) for k in range(num_written)]
        with db.begin(write=True) as txn:
            txn.put(b"__keys__", pickle_func(keys))
            txn.put(b"__len__", pickle_func(len(keys)))

        print("Flushing database ....")
        db.sync()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment