Skip to content

Instantly share code, notes, and snippets.

@akashyanpure
Forked from ramsrigouthamg/dialoGPT.py
Created February 27, 2021 20:23
Show Gist options
  • Select an option

  • Save akashyanpure/2092151982672af9b540bd4663c3a573 to your computer and use it in GitHub Desktop.

Select an option

Save akashyanpure/2092151982672af9b540bd4663c3a573 to your computer and use it in GitHub Desktop.
import torch
from transformers import AutoModelForCausalLM, AutoModelWithLMHead, AutoTokenizer
# Hugging Face checkpoint shared by the tokenizer and the model, so the two
# can never drift apart.
DIALOGPT_CHECKPOINT = "microsoft/DialoGPT-medium"

tokenizer = AutoTokenizer.from_pretrained(DIALOGPT_CHECKPOINT)
# AutoModelWithLMHead is deprecated in transformers; AutoModelForCausalLM is
# the replacement for causal (left-to-right) language models like DialoGPT.
model = AutoModelForCausalLM.from_pretrained(DIALOGPT_CHECKPOINT)
# Dialogue history, oldest turn first. The whole history is fed to the model
# below, so the last entry is the turn the generated reply responds to.
sentence_list = [
    "I heard you won the cricket match.",
    "I did!",
    "Awesome. Who did you play against?",
    "I played against the Aussies.",
    "Wow ! Was it a tough game?",
    "It was a tough game. It went on till the last over. They almost won.",
    "Where was the match?",
]
# Flatten the history into one prompt string. Every turn, including the
# last, is followed by the tokenizer's EOS token, which marks where one
# speaker's turn ends.
# str.join builds the string in one pass instead of re-copying it per turn.
all_sentences_string = "".join(
    sentence + tokenizer.eos_token for sentence in sentence_list
)
print("All sentences concatenated with EOS token:\n")
print(all_sentences_string)
# Encode the flattened history. return_tensors='pt' on a single string gives
# a batch of one, hence the [0] indexing below.
tokenized_all_sentences_string = tokenizer.encode(
    all_sentences_string, return_tensors="pt"
)
# A single, unpadded sequence: every position is real input, so the attention
# mask is all ones. Passing it explicitly avoids generate() having to guess.
attention_mask = torch.ones_like(tokenized_all_sentences_string)

# max_length counts the prompt tokens too; the reply gets what remains of 1000.
# pad_token_id is given explicitly (as in the DialoGPT model card); without it
# generate() warns and silently falls back to the EOS id.
reply_predicted = model.generate(
    tokenized_all_sentences_string,
    attention_mask=attention_mask,
    max_length=1000,
    pad_token_id=tokenizer.eos_token_id,
)

# generate() returns prompt + continuation; everything after the prompt tokens
# is the model's reply.
prefix_length = tokenized_all_sentences_string.shape[-1]
# skip_special_tokens=True also strips the EOS separators, so in the
# "with input" text the turns run together without a delimiter.
decoded_reply_predicted_with_input = tokenizer.decode(
    reply_predicted[0], skip_special_tokens=True
)
decoded_reply_predicted = tokenizer.decode(
    reply_predicted[0, prefix_length:], skip_special_tokens=True
)

print("\n\nPredicted reply along with initial input: ")
print(decoded_reply_predicted_with_input)
print("\n\nPredicted reply: ")
print(decoded_reply_predicted)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment