Created
March 4, 2023 09:05
-
-
Save happyme531/8e28502f1829985c3251264fdec8c74e to your computer and use it in GitHub Desktop.
llama interactive example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # This software may be used and distributed according to the terms of the GNU General Public License version 3. | |
| from typing import Tuple | |
| import os | |
| import sys | |
| import torch | |
| import fire | |
| import time | |
| import json | |
| from pathlib import Path | |
| from fairscale.nn.model_parallel.initialize import initialize_model_parallel | |
| from llama import ModelArgs, Transformer, Tokenizer, LLaMA | |
def setup_model_parallel() -> Tuple[int, int]:
    """Initialize torch.distributed (NCCL) and fairscale model parallelism.

    Reads LOCAL_RANK / WORLD_SIZE from the environment (as set by
    torchrun/torch.distributed.launch), pins this process to its GPU,
    and returns (local_rank, world_size).
    """
    rank = int(os.environ.get("LOCAL_RANK", -1))
    size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(size)
    torch.cuda.set_device(rank)

    # All ranks must sample identically, so seed every process the same way.
    torch.manual_seed(1)
    return rank, size
def load(
    ckpt_dir: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    """Load one model-parallel checkpoint shard and build a LLaMA generator.

    Args:
        ckpt_dir: Directory holding ``*.pth`` shards and ``params.json``.
        tokenizer_path: Path to the SentencePiece tokenizer model.
        local_rank: This process's rank; selects which shard to load.
        world_size: Total number of processes; must equal the shard count.
        max_seq_len: Maximum sequence length baked into the model args.
        max_batch_size: Maximum batch size baked into the model args.

    Returns:
        A ready-to-use ``LLaMA`` generator.

    Raises:
        AssertionError: If ``world_size`` does not match the number of
            checkpoint shards found in ``ckpt_dir``.
    """
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert world_size == len(
        checkpoints
    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.load(f)  # idiomatic: stream-parse instead of loads(f.read())
    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    # Build the model with fp16 parameters on GPU, then restore the default
    # so later tensors (e.g. tokenizer output) are created as CPU floats.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args)
    torch.set_default_tensor_type(torch.FloatTensor)
    # strict=False: each shard holds only this rank's slice of the weights,
    # so missing keys are expected here.
    model.load_state_dict(checkpoint, strict=False)
    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator
def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.9,
    top_p: float = 0.982,
    max_seq_len: int = 768,
    max_batch_size: int = 16,
    max_history_chars: int = 512,
):
    """Interactive REPL over a LLaMA model.

    Repeatedly reads a prompt from stdin, wraps it in a Question/Answer
    template, generates a completion, and prints it.

    Args:
        ckpt_dir: Directory with model checkpoint shards.
        tokenizer_path: Path to the tokenizer model file.
        temperature: Sampling temperature passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.
        max_seq_len: Model max sequence length (in tokens).
        max_batch_size: Model max batch size.
        max_history_chars: Keep only this many trailing *characters* of
            the concatenated history+prompt. NOTE(review): this trims by
            characters, not tokens, and is independent of ``max_seq_len``;
            previously hard-coded to 512, now parameterized.
    """
    local_rank, world_size = setup_model_parallel()
    # Only rank 0 should write to the console; silence the others.
    if local_rank > 0:
        sys.stdout = open(os.devnull, "w")

    generator = load(
        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
    )

    # seamlessmode: feed raw continuation prompts after the first turn.
    # historymode: carry previous turns forward as context.
    seamlessmode = False
    historymode = False
    is_first_prompt = True

    # Interactive loop.
    last_result = ""
    while True:
        prompt = input("Enter prompt: ")
        print("\n")
        # Wrap the user input in a Question/Answer template. In seamless
        # mode only the first turn is wrapped; later turns continue the
        # existing text verbatim.
        if not seamlessmode or is_first_prompt:
            prompt = "Question: " + prompt + "\nAnswer: "
            is_first_prompt = False
        # Prepend accumulated history, then keep only the trailing window.
        prompt = last_result + prompt
        prompt = prompt[-max_history_chars:]
        results = generator.generate(
            [prompt], max_gen_len=768, temperature=temperature, top_p=top_p
        )
        print(results[0])
        # Accumulate the new output into the history; outside seamless
        # mode, separate turns with a newline.
        if seamlessmode:
            last_result = last_result + results[0]
        else:
            last_result = last_result + results[0] + "\n"
        if not historymode:
            last_result = ""
        print("\n")
        # A dark blue separator between turns.
        print("\033[1;34;40m" + "-" * 20 + "\033[0m")
        print("\n")


if __name__ == "__main__":
    fire.Fire(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment