Created
November 16, 2023 23:38
-
-
Save arteagac/5cfb018d605f1cb809fe8c561896f4dc to your computer and use it in GitHub Desktop.
Revisions
-
arteagac created this gist
Nov 16, 2023. There are no files selected for viewing.
# Load a SQuAD-finetuned BERT QA model, expand its learned position
# embeddings beyond the 512-token limit by tiling them, then sanity-check
# that the model can find an answer placed past token 512.

# Load the model
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

MODEL_NAME = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)

## EXPAND POSITION EMBEDDINGS TO 1024 TOKENS
max_length = 1024
tokenizer.model_max_length = max_length
model.config.max_position_embeddings = max_length

embeddings = model.base_model.embeddings
embeddings.position_ids = torch.arange(max_length).expand((1, -1))
# BUGFIX: token_type_ids indexes an embedding table, so it must be an
# integer tensor; the original assigned float zeros. dtype=torch.long
# matches the dtype of BERT's own registered token_type_ids buffer.
embeddings.token_type_ids = torch.zeros(max_length, dtype=torch.long).expand((1, -1))

# Tile the original 512 learned position embeddings to cover max_length
# rows. Generalized from torch.cat((orig, orig)), which required
# max_length to be exactly double the original size; for max_length=1024
# this produces the identical tensor.
orig_pos_emb = embeddings.position_embeddings.weight
n_rep = -(-max_length // orig_pos_emb.size(0))  # ceiling division
embeddings.position_embeddings.weight = torch.nn.Parameter(
    orig_pos_emb.repeat(n_rep, 1)[:max_length]
)

## TEST THE MODEL IN A QUESTION ANSWERING TASK
question = "Where is the largest airport in the united states?"
# Simulate initial ~600 tokens by repeating 60 times a phrase of length 10,
# then place the answer to the question AFTER the 512-token boundary.
simul_tokens = " ".join(60 * ["This phrases simulates the initial 600 tokens by simple repetition. "])
context = simul_tokens + "The largest airport in the United States is located in Atlanta."

# Use the question answering model. Inference only: disable dropout and
# gradient tracking (the original tracked gradients needlessly).
model.eval()
inputs = tokenizer(question, context, return_tensors="pt", truncation=True)
with torch.no_grad():
    outputs = model(**inputs)

# Extractive QA: the answer is the token span between the argmax of the
# start logits and the argmax of the end logits (inclusive).
answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()
predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
# BUGFIX: a bare expression only echoes in a REPL; as a script the decoded
# answer must be printed to be seen.
print(tokenizer.decode(predict_answer_tokens))
# EXPECTED OUTPUT: atlanta
# The correct output demonstrates BERT was able to attend beyond 512 tokens
# thanks to the expansion in position embeddings.