@arteagac
Created November 16, 2023 23:38
Expand BERT beyond 512 tokens
# Load the model
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
## EXPAND POSITION EMBEDDINGS TO 1024 TOKENS
max_length = 1024
tokenizer.model_max_length = max_length
model.config.max_position_embeddings = max_length
model.base_model.embeddings.position_ids = torch.arange(max_length).expand((1, -1))
model.base_model.embeddings.token_type_ids = torch.zeros(max_length, dtype=torch.long).expand((1, -1))
# Duplicate the 512 learned position embeddings so positions 512-1023 reuse rows 0-511
orig_pos_emb = model.base_model.embeddings.position_embeddings.weight
model.base_model.embeddings.position_embeddings.weight = torch.nn.Parameter(torch.cat((orig_pos_emb, orig_pos_emb)))
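# Optional sanity check: the expanded table now has one row per position up to max_length
assert model.base_model.embeddings.position_embeddings.weight.shape[0] == max_length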
## TEST THE MODEL IN A QUESTION ANSWERING TASK
question = "Where is the largest airport in the united states?"
# Simulate ~600 initial tokens by repeating a ~10-word phrase 60 times
simul_tokens = " ".join(60*["This phrase simulates the initial 600 tokens by simple repetition. "])
# Place the answer to the question at the end of the 600 simulated tokens.
context = simul_tokens + "The largest airport in the United States is located in Atlanta."
# Use the question answering model; truncation now happens at 1024 tokens because
# tokenizer.model_max_length was raised above, so the ~600-token context fits
inputs = tokenizer(question, context, return_tensors="pt", truncation=True)
outputs = model(**inputs)
answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()
predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
tokenizer.decode(predict_answer_tokens)
# OUTPUT: atlanta
# The correct output demonstrates that BERT was able to attend beyond 512 tokens
# thanks to the expansion in position embeddings.
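
An alternative to duplicating the rows is to stretch the learned position-embedding table to the new length by linear interpolation; a minimal sketch, assuming the same checkpoint and max_length as above:

stretched_pos_emb = torch.nn.functional.interpolate(
    orig_pos_emb.data.T.unsqueeze(0),  # (1, hidden_size, 512)
    size=max_length, mode="linear", align_corners=True,
).squeeze(0).T                         # (max_length, hidden_size)
model.base_model.embeddings.position_embeddings = torch.nn.Embedding.from_pretrained(stretched_pos_emb, freeze=False)

Either way, the model never saw positions beyond 512 during pre-training, so answer quality on long inputs is not guaranteed.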

deshwalmahesh commented Nov 27, 2023

For anyone looking at this: you can use the snippet below to check that hitting the base pretrained model (without expanding the position embeddings) with a 1024-token input gives you the error: RuntimeError: The size of tensor a (1024) must match the size of tensor b (512) at non-singleton dimension 1

But if you change the model to "microsoft/deberta-v3-base", you can use even 8096 tokens without errors, because of the type of attention and position embeddings it uses (a sketch of this swap follows the snippet below).

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")


## TEST THE MODEL IN A QUESTION ANSWERING TASK
question = "Where is the largest airport in the united states?"

# Simulate ~600 initial tokens by repeating a ~10-word phrase 60 times
simul_tokens = " ".join(60*["This phrase simulates the initial 600 tokens by simple repetition. "])
# Place the answer to the question at the end of the 600 simulated tokens.
context = simul_tokens + "The largest airport in the United States is located in Atlanta."

# Use the question answering model; padding the input to 1024 tokens triggers the
# RuntimeError above, because BERT only has 512 position embeddings
inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=1024, padding="max_length")
outputs = model(**inputs)
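
A minimal sketch of the DeBERTa swap described above (note this checkpoint is not fine-tuned for question answering, so its QA head is randomly initialized; the point is only that a 1024-token input runs without the size-mismatch error):

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

deberta_tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
deberta_model = AutoModelForQuestionAnswering.from_pretrained("microsoft/deberta-v3-base")

# Same question and context as above, padded to 1024 tokens
long_inputs = deberta_tokenizer(question, context, return_tensors="pt", truncation=True,
                                max_length=1024, padding="max_length")
with torch.no_grad():
    # Runs without the 512-position error, since DeBERTa relies on relative position embeddings
    long_outputs = deberta_model(**long_inputs)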
