Snippet showing the Llama 3.2 3B checkpoint structure, as an example of how model checkpoints saved to Hugging Face repos are split by tensor across multiple safetensors shards; by convention each shard is kept under a 5 GB maximum, even though this is not a hard limit.
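The sharding happens when the model is saved. A minimal sketch of producing such a split checkpoint yourself with transformers' save_pretrained and its max_shard_size parameter (the "./llama-3.2-3b-sharded" output directory is just an illustrative path):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")

# save_pretrained shards the state dict by tensor: each shard stays under
# max_shard_size ("5GB" is the default) and an index JSON records which
# tensor lives in which shard file.
model.save_pretrained("./llama-3.2-3b-sharded", max_shard_size="5GB")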
#!/usr/bin/env python
"""
See: https://huggingface.co/docs/safetensors/en/index
"""
from pathlib import Path
from pprint import pp
from time import perf_counter

from safetensors import safe_open
from transformers import AutoModelForCausalLM, AutoTokenizer

t_start = perf_counter()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
t_tok = perf_counter() - t_start
print(f"tokenizer loaded in {t_tok:.2f} seconds")

t_1 = perf_counter()
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")
t_2 = perf_counter() - t_1
print(f"model loaded in {t_2:.2f} seconds")

# Paths into the local Hugging Face Hub cache; the blob names are the
# content hashes of the two safetensors shards of Llama 3.2 3B.
HF_CACHE_DIR = Path.home() / "Desktop" / "huggingface"
HF_HUB_DIR = HF_CACHE_DIR / "hub"
st1 = "models--meta-llama--Llama-3.2-3B/blobs/584d8d3e3f82f7964955174dfe5e3b1cf117a9d859f022cfdf7fcb884856e002"
st2 = "models--meta-llama--Llama-3.2-3B/blobs/4719a04514ec2f060240711b7c33ab21187cac730ecaba3040b7a0fd95a9cefb"

# Load each shard directly with safetensors and time the load.
t_start_st_load = perf_counter()
tensors1 = {}
with safe_open(HF_HUB_DIR / st1, framework="pt", device="cpu") as f:
    for k in f.keys():
        tensors1[k] = f.get_tensor(k)
t_st_load = perf_counter() - t_start_st_load
print(f"Shard 1: safetensors load took {t_st_load:.2f} seconds")

t_start_st_load = perf_counter()
tensors2 = {}
with safe_open(HF_HUB_DIR / st2, framework="pt", device="cpu") as f:
    for k in f.keys():
        tensors2[k] = f.get_tensor(k)
t_st_load = perf_counter() - t_start_st_load
print(f"Shard 2: safetensors load took {t_st_load:.2f} seconds")

# Print the tensor names in each shard, then drop into the debugger
# to explore the loaded tensors interactively.
pp(list(tensors1.keys()))
breakpoint()
pp(list(tensors2.keys()))
breakpoint()
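To see which tensor lives in which shard without hardcoding blob hashes, a sharded safetensors checkpoint ships a model.safetensors.index.json whose "weight_map" maps tensor names to shard filenames. A minimal sketch using huggingface_hub's hf_hub_download to resolve the index file from the cache (downloading it if necessary):

import json
from pprint import pp

from huggingface_hub import hf_hub_download

# Resolve (and, if needed, download) the shard index from the Hub.
index_path = hf_hub_download("meta-llama/Llama-3.2-3B", "model.safetensors.index.json")
with open(index_path) as fh:
    index = json.load(fh)

# "weight_map" maps every tensor name to the shard file that contains it.
pp(index["weight_map"])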