# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Optional, Tuple
import os
import sys
import torch
import fire
import time
import json

from pathlib import Path

from fairscale.nn.model_parallel.initialize import initialize_model_parallel

from llama import ModelArgs, Transformer, Tokenizer, LLaMA


def setup_model_parallel() -> Tuple[int, int]:
    """Initialise torch.distributed (NCCL) and fairscale model parallelism.

    Reads ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment (presumably
    set by a launcher such as ``torchrun``), binds this process to CUDA device
    ``local_rank`` and seeds the torch RNG with the same value on every rank.

    Returns:
        ``(local_rank, world_size)``.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size


def load(
    ckpt_dir: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    """Build the model-parallel Transformer for this rank and wrap it in LLaMA.

    Args:
        ckpt_dir: Directory holding one ``*.pth`` shard per model-parallel rank
            plus ``params.json`` with the ModelArgs hyper-parameters.
        tokenizer_path: Path to the tokenizer model file.
        local_rank: Index of the shard (in sorted filename order) to load.
        world_size: Number of ranks; must equal the number of ``*.pth`` shards.
        max_seq_len: Maximum sequence length passed to ModelArgs.
        max_batch_size: Maximum batch size passed to ModelArgs.

    Returns:
        A LLaMA generator wrapping the loaded model and tokenizer.
    """
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert world_size == len(
        checkpoints
    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.load(f)

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    # The vocabulary size comes from the tokenizer, not from params.json.
    model_args.vocab_size = tokenizer.n_words
    # Build parameters directly as fp16 CUDA tensors, then restore the default
    # so later tensor creation is unaffected.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args)
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator


def _read_prompt(world_size: int) -> Optional[str]:
    """Read one prompt on global rank 0 and share it with every rank.

    Only rank 0 touches stdin: if every model-parallel rank called ``input()``
    they would race for lines of the shared terminal and could run the
    collective ``generate`` call on different prompts.

    Returns:
        The entered text, or ``None`` once stdin is exhausted (EOF), so that
        all ranks leave the interactive loop together.
    """
    text: Optional[str] = None
    if torch.distributed.get_rank() == 0:
        try:
            text = input("Enter prompt: ")
        except EOFError:
            text = None
    if world_size > 1:
        box = [text]
        torch.distributed.broadcast_object_list(box, src=0)
        text = box[0]
    return text


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.9,
    top_p: float = 0.982,
    max_seq_len: int = 768,
    max_batch_size: int = 16,
    max_gen_len: int = 768,
    max_prompt_chars: int = 512,
    seamless_mode: bool = False,
    history_mode: bool = False,
):
    """Run an interactive prompt/response loop against a LLaMA checkpoint.

    Args:
        ckpt_dir: Checkpoint directory (see ``load``).
        tokenizer_path: Tokenizer model file.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.
        max_seq_len: Model maximum sequence length.
        max_batch_size: Model maximum batch size.
        max_gen_len: ``max_gen_len`` forwarded to ``generate``.
        max_prompt_chars: Keep only the last this-many *characters* (not
            tokens) of the prompt, history included. Must be positive.
        seamless_mode: Wrap only the first input in the "Question:/Answer:"
            template; later inputs are passed through verbatim.
        history_mode: Prepend the previous model output to each new prompt.

    Raises:
        ValueError: If ``max_prompt_chars`` is not positive.
    """
    if max_prompt_chars <= 0:
        # prompt[-0:] would return the whole string and silently disable trimming.
        raise ValueError(f"max_prompt_chars must be positive, got {max_prompt_chars}")

    local_rank, world_size = setup_model_parallel()
    if local_rank > 0:
        # Silence duplicate output from the non-primary ranks on this node.
        sys.stdout = open(os.devnull, "w")

    generator = load(
        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
    )

    is_first_prompt = True
    history = ""

    # Interactive mode
    while True:
        user_text = _read_prompt(world_size)
        if user_text is None:
            break
        print("\n")

        # Wrap the input as "Question: <text>\nAnswer: ", except for follow-up
        # inputs in seamless mode, which continue the running text verbatim.
        if seamless_mode and not is_first_prompt:
            prompt = user_text
        else:
            prompt = "Question: " + user_text + "\nAnswer: "
        is_first_prompt = False

        # Prepend the history (empty unless history_mode) and keep only the
        # trailing max_prompt_chars characters.
        prompt = (history + prompt)[-max_prompt_chars:]

        results = generator.generate(
            [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
        )
        print(results[0])

        if history_mode:
            # NOTE(review): assumes LLaMA.generate returns prompt + completion
            # (as the reference llama v1 implementation does). The prompt
            # already contains the previous history, so the result itself is
            # the new history; appending it to the old history duplicated it.
            history = results[0] if seamless_mode else results[0] + "\n"

        print("\n")
        # a bold blue-on-black separator line
        print("\033[1;34;40m" + "-" * 20 + "\033[0m")
        print("\n")


if __name__ == "__main__":
    fire.Fire(main)