@amar-enkhbat
Created June 11, 2021 05:19
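from transformers import AlbertTokenizer
from transformers import DistilBertForQuestionAnswering
from transformers import DistilBertModel
import torch

# Load model
# Base (not fine-tuned) encoder; loaded here but not used in the rest of the script
pt_model = DistilBertModel.from_pretrained('laboro-ai/distilbert-base-japanese')
# SentencePiece tokenizer (hence "sp_"), loaded through the ALBERT tokenizer class
sp_tokenizer = AlbertTokenizer.from_pretrained('laboro-ai/distilbert-base-japanese')
# Question-answering model fine-tuned on DDQA; this is the model that is run below
model = DistilBertForQuestionAnswering.from_pretrained('laboro-ai/distilbert-base-japanese-finetuned-ddqa')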
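# Input data
# question: "Who is Jim?"
# context: "I love cats. I don't know. You say nice things. Jim is a good person."
question, context = "ジムとはだれ?", "俺は猫が大好きだ。わからない。いいこと言いますね。ジムはいいひとだ"
# Tokenize inputs (question and context are encoded together as one sequence pair)
inputs = sp_tokenizer(question, context, return_tensors='pt')
input_ids = inputs["input_ids"].tolist()[0]
# DistilBERT's forward() takes no token_type_ids, so drop them if the tokenizer produced any
inputs.pop('token_type_ids', None)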
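# Run model
model.eval()
with torch.no_grad():  # inference only, so skip gradient tracking
    answer_start_scores, answer_end_scores = model(**inputs, return_dict=False)
# Most likely start token, and one past the most likely end token (slice end is exclusive)
answer_start = torch.argmax(answer_start_scores).item()
answer_end = torch.argmax(answer_end_scores).item() + 1
# Map the selected token ids back to a text string
answer = sp_tokenizer.convert_tokens_to_string(sp_tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))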
# Results
print("Context:")
print(context)
print("")
print("Question:")
print(question)
print("")
print("Answer:")
print(answer)
###### Output #########
# Context:
# 俺は猫が大好きだ。わからない。いいこと言いますね。ジムはいいひとだ
# Question:
# ジムとはだれ?
# Answer:
# いいひと
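
The script above takes the argmax of the start scores and of the end scores independently, so nothing prevents the predicted end from landing before the predicted start, in which case the slice is empty. A minimal sketch of one common alternative follows: score every (start, end) pair jointly and keep the best pair with start <= end. The function name `best_span` and the `max_answer_len` parameter are hypothetical and not part of the gist or of the transformers API, and the sketch does not exclude question tokens from the search.

```python
def best_span(start_scores, end_scores, max_answer_len=30):
    """Return the (start, end) token indices with the highest combined score.

    Both arguments are plain sequences of floats, one score per token.
    `end` is inclusive, and only spans with start <= end and a length of
    at most `max_answer_len` tokens are considered.
    """
    best = (0, 0)
    best_score = float("-inf")
    for start, start_score in enumerate(start_scores):
        # Only look at end positions at or after the start, within the length limit
        last_end = min(start + max_answer_len, len(end_scores))
        for end in range(start, last_end):
            score = start_score + end_scores[end]
            if score > best_score:
                best_score = score
                best = (start, end)
    return best


# Toy scores where independent argmax would give start=2, end=1 (an empty span)
toy_start = [0.1, 0.2, 5.0, 0.3]
toy_end = [0.0, 4.0, 0.1, 3.0]
span = best_span(toy_start, toy_end)
print(span)
```

With the variables from the script, this could be called as `best_span(answer_start_scores[0].tolist(), answer_end_scores[0].tolist())`, and the answer tokens would then be `input_ids[start:end + 1]` because the returned end index is inclusive.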