Skip to content

Instantly share code, notes, and snippets.

@anilkeshwani
Last active November 26, 2024 16:03
Show Gist options
  • Select an option

  • Save anilkeshwani/155b6379ca6a9cc55a237df2de6e0a58 to your computer and use it in GitHub Desktop.

Select an option

Save anilkeshwani/155b6379ca6a9cc55a237df2de6e0a58 to your computer and use it in GitHub Desktop.
Snippet showing the Llama 3.2 3B checkpoint structure, as an example of how a model's tensors are split across multiple files when checkpoints are saved to Hugging Face repos (each file is kept under 5 GB, even though this is not a hard limit).
#!/usr/bin/env python
"""
Inspect how the Llama 3.2 3B checkpoint is laid out on disk.

Loads the tokenizer and the causal-LM model for "meta-llama/Llama-3.2-3B",
printing how long each load took, then opens two safetensors blob files from
a local Hugging Face hub directory, reads every tensor in each file, prints
the tensor names per file, and stops at a breakpoint after each listing.

See: https://huggingface.co/docs/safetensors/en/index
"""
from pathlib import Path
from pprint import pp
from time import perf_counter
from safetensors import safe_open
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub repository that both the tokenizer and the model are loaded from.
repo_id = "meta-llama/Llama-3.2-3B"

# Load the tokenizer and report the elapsed time.
tok_began = perf_counter()
tokenizer = AutoTokenizer.from_pretrained(repo_id)
tok_elapsed = perf_counter() - tok_began
print(f"tokenizer loaded in {tok_elapsed:.2f} seconds")

# Same measurement for the causal-LM model itself.
model_began = perf_counter()
model = AutoModelForCausalLM.from_pretrained(repo_id)
model_elapsed = perf_counter() - model_began
print(f"model loaded in {model_elapsed:.2f} seconds")
# Local Hugging Face cache root used by this script, and its "hub" subdirectory.
HF_CACHE_DIR = Path.home().joinpath("Desktop", "huggingface")
HF_HUB_DIR = HF_CACHE_DIR.joinpath("hub")

# Paths, relative to HF_HUB_DIR, of two blob files belonging to the
# Llama-3.2-3B model entry; the script opens these with safe_open.
st1 = "models--meta-llama--Llama-3.2-3B/blobs/584d8d3e3f82f7964955174dfe5e3b1cf117a9d859f022cfdf7fcb884856e002"
st2 = "models--meta-llama--Llama-3.2-3B/blobs/4719a04514ec2f060240711b7c33ab21187cac730ecaba3040b7a0fd95a9cefb"
def _read_shard(shard_path):
    """Return a ``{name: tensor}`` dict with every tensor in a safetensors file.

    Args:
        shard_path: Path of the safetensors file to open.

    Returns:
        A dict mapping each key reported by the file to the tensor stored
        under it. The file is opened with ``framework="pt"`` and
        ``device="cpu"``.
    """
    with safe_open(shard_path, framework="pt", device="cpu") as f:
        return {k: f.get_tensor(k) for k in f.keys()}


# Read the first file in full and report how long it took. The message names
# safe_open because that, not torch.load, is what performs the read.
t_start_st_load = perf_counter()
tensors1 = _read_shard(HF_HUB_DIR / st1)
t_st_load = perf_counter() - t_start_st_load
print(f"Tensor 1: safe_open load took {t_st_load:.2f} seconds")

# Same for the second file.
t_start_st_load = perf_counter()
tensors2 = _read_shard(HF_HUB_DIR / st2)
t_st_load = perf_counter() - t_start_st_load
print(f"Tensor 2: safe_open load took {t_st_load:.2f} seconds")
# Pretty-print the tensor names of each loaded file in turn, dropping into the
# debugger after each listing so the loaded objects can be inspected.
for shard_tensors in (tensors1, tensors2):
    pp(list(shard_tensors))
    breakpoint()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment