Skip to content

Instantly share code, notes, and snippets.

@arteagac
Created November 16, 2023 23:38
Show Gist options
  • Select an option

  • Save arteagac/5cfb018d605f1cb809fe8c561896f4dc to your computer and use it in GitHub Desktop.

Select an option

Save arteagac/5cfb018d605f1cb809fe8c561896f4dc to your computer and use it in GitHub Desktop.

Revisions

  1. arteagac created this gist Nov 16, 2023.
    36 changes: 36 additions & 0 deletions expand_bert_beyond_512_tokens.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,36 @@
    # Load the model
    from transformers import AutoTokenizer, AutoModelForQuestionAnswering
    import torch

    tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
    model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")

    ## EXPAND POSITION EMBEDDINGS TO 1024 TOKENS
    max_length = 1024
    tokenizer.model_max_length = max_length
    model.config.max_position_embeddings = max_length
    model.base_model.embeddings.position_ids = torch.arange(max_length).expand((1, -1))
    model.base_model.embeddings.token_type_ids = torch.zeros(max_length).expand((1, -1))
    orig_pos_emb = model.base_model.embeddings.position_embeddings.weight
    model.base_model.embeddings.position_embeddings.weight = torch.nn.Parameter(torch.cat((orig_pos_emb, orig_pos_emb)))

    ## TEST THE MODEL IN A QUESTION ANSWERING TASK
    question = "Where is the largest airport in the united states?"

    # Simulate initial ~600 tokens by repeating 60 times a phrase of length 10
    simul_tokens = " ".join(60*["This phrases simulates the initial 600 tokens by simple repetition. "])
    # Place the answer to the question at the end of the 600 simulated tokens.
    context = simul_tokens + "The largest airport in the United States is located in Atlanta."

    # Use the question answering model
    inputs = tokenizer(question, context, return_tensors="pt", truncation=True)
    outputs = model(**inputs)
    answer_start_index = outputs.start_logits.argmax()
    answer_end_index = outputs.end_logits.argmax()
    predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    tokenizer.decode(predict_answer_tokens)

    # OUTPUT: atlanta

    # The correct output demonstrates BERT was able to attend beyond 512 tokens
    # thanks to the expansion in position embeddings.