Created June 11, 2021 05:19
from transformers import AlbertTokenizer
from transformers import DistilBertForQuestionAnswering
import torch

# Load tokenizer and fine-tuned QA model
sp_tokenizer = AlbertTokenizer.from_pretrained('laboro-ai/distilbert-base-japanese')
model = DistilBertForQuestionAnswering.from_pretrained('laboro-ai/distilbert-base-japanese-finetuned-ddqa')

# Input data
question, context = "ジムとはだれ?", "俺は猫が大好きだ。わからない。いいこと言いますね。ジムはいいひとだ"

# Tokenize question and context as a single pair
inputs = sp_tokenizer(question, context, return_tensors='pt')
input_ids = inputs["input_ids"].tolist()[0]
# DistilBERT does not accept token_type_ids, so drop them
inputs.pop('token_type_ids')

# Run model
model.eval()
with torch.no_grad():
    answer_start_scores, answer_end_scores = model(**inputs, return_dict=False)

# The answer span runs from the highest-scoring start token
# through the highest-scoring end token (inclusive)
answer_start = torch.argmax(answer_start_scores)
answer_end = torch.argmax(answer_end_scores) + 1
answer = sp_tokenizer.convert_tokens_to_string(sp_tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))

# Results
print("Context:")
print(context)
print("")
print("Question:")
print(question)
print("")
print("Answer:")
print(answer)
###### Output #########
Context:
俺は猫が大好きだ。わからない。いいこと言いますね。ジムはいいひとだ

Question:
ジムとはだれ?

Answer:
いいひと
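The span-extraction step above can be illustrated in isolation. The sketch below uses made-up logits and a hypothetical token list instead of real model output, but applies exactly the same argmax rule: the answer starts at the position with the highest start logit and ends one past the position with the highest end logit.

```python
import torch

# Hypothetical tokens and logits standing in for real model output
tokens = ["[CLS]", "who", "is", "jim", "[SEP]",
          "jim", "is", "a", "good", "person", "[SEP]"]
start_logits = torch.tensor([0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.5, 0.2, 0.0])
end_logits   = torch.tensor([0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.5, 3.0, 0.0])

# Same rule as the snippet: start at the highest start logit,
# end one past the highest end logit
answer_start = torch.argmax(start_logits)
answer_end = torch.argmax(end_logits) + 1
answer = " ".join(tokens[answer_start:answer_end])
print(answer)  # a good person
```

Note that the two argmaxes are taken independently, so with an ill-behaved model the predicted end could precede the start; production QA pipelines typically search for the best valid (start ≤ end) pair instead.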