Skip to content

Instantly share code, notes, and snippets.

@happyme531
Created March 4, 2023 09:05
Show Gist options
  • Select an option

  • Save happyme531/8e28502f1829985c3251264fdec8c74e to your computer and use it in GitHub Desktop.

Select an option

Save happyme531/8e28502f1829985c3251264fdec8c74e to your computer and use it in GitHub Desktop.
LLaMA interactive example
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
from typing import Tuple
import os
import sys
import torch
import fire
import time
import json
from pathlib import Path
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from llama import ModelArgs, Transformer, Tokenizer, LLaMA
def setup_model_parallel() -> Tuple[int, int]:
    """Initialize torch.distributed (NCCL) and fairscale model parallelism.

    Reads LOCAL_RANK / WORLD_SIZE from the environment (as set by
    torchrun / torch.distributed.launch) and returns them as
    ``(local_rank, world_size)``.  Falls back to -1 when unset.
    """
    rank = int(os.environ.get("LOCAL_RANK", -1))
    size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(size)
    torch.cuda.set_device(rank)

    # Seed must be identical across all ranks so sampling stays in sync.
    torch.manual_seed(1)
    return rank, size
def load(
    ckpt_dir: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    """Load one model-parallel LLaMA shard and wrap it in a generator.

    Each rank loads the checkpoint shard matching its ``local_rank``;
    the number of ``*.pth`` shards in ``ckpt_dir`` must equal
    ``world_size``.  Returns a ready-to-use ``LLaMA`` generator.
    """
    t0 = time.time()

    shard_paths = sorted(Path(ckpt_dir).glob("*.pth"))
    assert world_size == len(
        shard_paths
    ), f"Loading a checkpoint for MP={len(shard_paths)} but world size is {world_size}"

    print("Loading")
    state_dict = torch.load(shard_paths[local_rank], map_location="cpu")

    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.load(f)

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words

    # Build the model with fp16 parameters on GPU, then restore the
    # default tensor type so later allocations are fp32 on CPU.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args)
    torch.set_default_tensor_type(torch.FloatTensor)

    # strict=False: each rank only holds its model-parallel slice.
    model.load_state_dict(state_dict, strict=False)

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - t0:.2f} seconds")
    return generator
def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.9,
    top_p: float = 0.982,
    max_seq_len: int = 768,
    max_batch_size: int = 16,
    max_gen_len: int = 768,
    seamlessmode: bool = False,
    historymode: bool = False,
):
    """Interactive REPL over a LLaMA checkpoint.

    Repeatedly reads a prompt from stdin, wraps it in a
    ``Question:``/``Answer:`` template, generates a completion, and
    prints it.

    Args:
        ckpt_dir: Directory holding the ``*.pth`` shards and params.json.
        tokenizer_path: Path to the sentencepiece tokenizer model.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.
        max_seq_len: Model context length; the concatenated prompt is
            trimmed to this many characters (NOTE: a character-level
            approximation of the token limit, kept from the original).
        max_batch_size: Max batch size for the model cache.
        max_gen_len: Max tokens to generate per turn (was hard-coded 768).
        seamlessmode: Only prepend the Question/Answer template on the
            first turn; later input continues the text verbatim.
        historymode: Carry prior output into the next prompt as context.
    """
    local_rank, world_size = setup_model_parallel()
    # Only rank 0 talks to the terminal.
    if local_rank > 0:
        sys.stdout = open(os.devnull, "w")

    generator = load(
        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
    )

    is_first_prompt = True
    last_result = ""
    while True:
        prompt = input("Enter prompt: ")
        print("\n")
        # Re-encode the prompt as a Question/Answer pair.  In seamless
        # mode only the very first turn gets the template; afterwards the
        # raw input is appended so generation continues uninterrupted.
        # (The original had a dead `else: prompt = prompt` branch here.)
        if seamlessmode:
            if is_first_prompt:
                prompt = "Question: " + prompt + "\nAnswer: "
                is_first_prompt = False
        else:
            prompt = "Question: " + prompt + "\nAnswer: "

        # Prepend accumulated history, then trim to the configured context
        # size (was hard-coded to 512 even though max_seq_len defaults to
        # 768 — now consistent with the parameter).
        prompt = last_result + prompt
        prompt = prompt[-max_seq_len:]

        results = generator.generate(
            [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
        )
        print(results[0])

        # Accumulate history; outside seamless mode, separate turns with
        # a newline.  Without historymode, context is discarded each turn.
        if seamlessmode:
            last_result = last_result + results[0]
        else:
            last_result = last_result + results[0] + "\n"
        if not historymode:
            last_result = ""

        print("\n")
        # A dark blue separator between turns.
        print("\033[1;34;40m" + "-" * 20 + "\033[0m")
        print("\n")


if __name__ == "__main__":
    fire.Fire(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment