-
-
Save akashyanpure/2092151982672af9b540bd4663c3a573 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generate the next conversational reply with Microsoft's DialoGPT-medium.
#
# DialoGPT expects the full dialogue history as one string in which every
# utterance (from either speaker) is terminated by the tokenizer's EOS
# token; the model then continues the conversation from there.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
# AutoModelWithLMHead is deprecated in transformers; for a decoder-only
# chat model such as DialoGPT the correct head class is AutoModelForCausalLM.
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

sentence_list = [
    "I heard you won the cricket match.",
    "I did!",
    "Awesome. Who did you play against?",
    "I played against the Aussies.",
    "Wow ! Was it a tough game?",
    "It was a tough game. It went on till the last over. They almost won.",
    "Where was the match?",
]

# Build the history with str.join (linear) instead of repeated += (quadratic).
# Like the original loop, this also appends EOS after the final sentence.
all_sentences_string = "".join(s + tokenizer.eos_token for s in sentence_list)
print("All sentences concatenated with EOS token:\n")
print(all_sentences_string)

tokenized_all_sentences_string = tokenizer.encode(
    all_sentences_string, return_tensors="pt"
)

# GPT-2-family tokenizers define no pad token, so generate() warns unless one
# is supplied; the DialoGPT convention is to pad with EOS.  no_grad() skips
# gradient bookkeeping during pure inference.
with torch.no_grad():
    reply_predicted = model.generate(
        tokenized_all_sentences_string,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    )

# generate() returns the prompt followed by the new tokens; slice off the
# prompt length to isolate the model's reply.
prefix_length = tokenized_all_sentences_string.shape[-1]
decoded_reply_predicted_with_input = tokenizer.decode(
    reply_predicted[0], skip_special_tokens=True
)
decoded_reply_predicted = tokenizer.decode(
    reply_predicted[:, prefix_length:][0], skip_special_tokens=True
)

print("\n\nPredicted reply along with initial input: ")
print(decoded_reply_predicted_with_input)
print("\n\nPredicted reply: ")
print(decoded_reply_predicted)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.