@sitch
Created January 3, 2023 02:31

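import time

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# The timer covers model download/loading as well as generation.
start = time.perf_counter()

# Use the GPU when available; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained("stanford-crfm/pubmed_gpt_tokenizer")
model = GPT2LMHeadModel.from_pretrained("stanford-crfm/pubmedgpt").to(device)
model.eval()  # inference mode: disables dropout

# No trailing space: GPT-2-style BPE would encode it as a separate token.
# Calling the tokenizer directly also returns an attention mask.
inputs = tokenizer("Photosynthesis is", return_tensors="pt").to(device)

# Sample from the 50 most likely tokens at each step, up to 50 tokens in total
# (prompt included). GPT-2-style models define no pad token, so reuse EOS.
sample_output = model.generate(
    **inputs, do_sample=True, max_length=50, top_k=50,
    pad_token_id=tokenizer.eos_token_id,
)
print("Output:\n" + 100 * "-")
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

end = time.perf_counter()
print(f"Elapsed: {end - start:.2f} s")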
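The script takes a single start/end pair, so the printed time lumps model loading together with generation. A minimal sketch of a per-phase timer, assuming you want those two numbers reported separately; `timed` is a hypothetical helper name, and `time.sleep` stands in for the real work so the example runs without downloading the model.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, results):
    """Record the wall-clock duration of the enclosed block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Store the elapsed seconds even if the block raises.
        results[label] = time.perf_counter() - start


# Usage sketch with stand-in work. In the gist, the "load" phase would wrap the
# from_pretrained(...) calls and the "generate" phase the model.generate(...) call.
timings = {}
with timed("load", timings):
    time.sleep(0.05)  # placeholder for loading the tokenizer and model
with timed("generate", timings):
    time.sleep(0.02)  # placeholder for model.generate(...)

for label, seconds in timings.items():
    print(f"{label}: {seconds:.3f} s")
```

When a phase runs on the GPU, CUDA kernels execute asynchronously, so calling `torch.cuda.synchronize()` at the end of the block makes the measured time reflect the finished work rather than just the kernel launches.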
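The `do_sample=True, top_k=50` arguments ask `generate` for top-k sampling: at each step only the k highest-scoring next tokens are kept, and one of them is drawn in proportion to its softmax probability. A conceptual sketch of that rule on a single list of logits, in plain Python; `top_k_sample` is a hypothetical name, and this is not the transformers implementation, which operates on batched tensors and can combine top-k with other settings such as temperature.

```python
import math
import random


def top_k_sample(logits, k, rng=random):
    """Sample a token index from `logits`, restricted to the k highest-scoring entries."""
    if k < 1:
        raise ValueError("k must be at least 1")
    # Indices of the k largest logits (sorted is stable, so ties keep index order).
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax over the kept logits only; subtracting the max avoids overflow.
    m = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - m) for i in top]
    # random.choices normalizes the weights, so no explicit division is needed.
    return rng.choices(top, weights=weights, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0, 3.0]
rng = random.Random(0)
print(top_k_sample(logits, k=1, rng=rng))  # → 4 (k=1 reduces to greedy argmax)
samples = [top_k_sample(logits, k=2, rng=rng) for _ in range(1000)]
print(sorted(set(samples)))  # only indices 4 and 0 can ever appear
```

With k=1 the rule is greedy decoding; a larger k such as the script's 50 trades some coherence for more varied output.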