Last active
November 27, 2019 03:40
Run a BERT model trained on a Japanese news corpus. Thanks to Morinaga-san of Stockmark Inc. for making the model publicly available.
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Running the [pre-trained BERT model (MeCab-based) trained on a large-scale Japanese business news corpus](https://qiita.com/mkt3/items/3c1278339ff1bcc0187f)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T00:54:45.075324Z", | |
| "start_time": "2019-04-11T00:54:45.069771Z" | |
| } | |
| }, | |
| "source": [ | |
| "### Preparation\n", | |
| "Download the PyTorch version of the model files from the [download link](https://drive.google.com/open?id=1iDlmhGgJ54rkVBtZvgMlgbuNwtFQ50V-).\n", | |
| "\n", | |
| "Create a `tar.gz` archive whose top level contains the following three files:\n", | |
| "- bert_config.json\n", | |
| "- pytorch_model.bin\n", | |
| "- vocab.txt\n", | |
| "\n", | |
| "The uncompressed vocab.txt is also needed on its own.\n", | |
| "\n", | |
| "Install the module that provides a PyTorch interface to BERT:\n", | |
| "`pip install pytorch-pretrained-bert`" | |
| ] | |
| }, | |
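The archive step described above can be sketched as shell commands. The directory name `PyTorchVer` and the archive name `pytorchver.tar.gz` match the paths used later in this notebook, but both are assumptions about your local layout:

```shell
# Package the three downloaded files so they sit at the root of the archive,
# which is where pytorch-pretrained-bert's from_pretrained() expects them.
cd ./PyTorchVer
tar czf ../pytorchver.tar.gz bert_config.json pytorch_model.bin vocab.txt
cd ..
tar tzf pytorchver.tar.gz   # should list exactly the three files
```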
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:15:52.274037Z", | |
| "start_time": "2019-04-11T01:15:51.727863Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import collections\n", | |
| "import logging\n", | |
| "import os\n", | |
| "\n", | |
| "import torch\n", | |
| "import MeCab\n", | |
| "from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM, WordpieceTokenizer\n", | |
| "from pytorch_pretrained_bert.tokenization import load_vocab" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:15:52.891854Z", | |
| "start_time": "2019-04-11T01:15:52.888981Z" | |
| }, | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "logging.basicConfig(level=logging.INFO)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-09T06:29:11.357348Z", | |
| "start_time": "2019-04-09T06:29:11.354945Z" | |
| } | |
| }, | |
| "source": [ | |
| "Prepare the MeCab dictionaries.\n", | |
| "\n", | |
| "Set `DICDIR` to the MeCab IPADic NEologd dictionary directory.\n", | |
| "\n", | |
| "For `USERDIC`, create a CSV file with the entries below and compile it.\n", | |
| "The reason is that each special token used by BERT should be treated as a single word.\n", | |
| "\n", | |
| "```\n", | |
| "[UNK],1285,1285,100,名詞,記号,*,*,*,*,[UNK],[UNK],[UNK]\n", | |
| "[SEP],1285,1285,100,名詞,記号,*,*,*,*,[SEP],[SEP],[SEP]\n", | |
| "[PAD],1285,1285,100,名詞,記号,*,*,*,*,[PAD],[PAD],[PAD]\n", | |
| "[CLS],1285,1285,100,名詞,記号,*,*,*,*,[CLS],[CLS],[CLS]\n", | |
| "[MASK],1285,1285,100,名詞,記号,*,*,*,*,[MASK],[MASK],[MASK]\n", | |
| "```\n", | |
| "\n", | |
| "Compile it along these lines: `/usr/local/libexec/mecab/mecab-dict-index -d /path/to/mecab-ipadic-neologd/build/mecab-ipadic-2.7.0-20070801-neologd-20180308 -u user.dic -f utf-8 -t utf-8 your_dictionary.csv`" | |
| ] | |
| }, | |
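Since the five CSV rows follow one fixed pattern, they can also be generated with a small script; a minimal sketch (the output file name `bert_special_tokens.csv` is an assumption, not part of the original workflow):

```python
# Generate the MeCab user-dictionary CSV rows for BERT's special tokens.
# Left/right context id 1285 and cost 100 follow the rows shown above.
SPECIAL_TOKENS = ["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]

def make_userdic_rows(tokens=SPECIAL_TOKENS):
    # Fields: surface, left-id, right-id, cost, POS fields,
    # then base form / reading / pronunciation (all set to the token itself).
    return [f"{t},1285,1285,100,名詞,記号,*,*,*,*,{t},{t},{t}" for t in tokens]

if __name__ == "__main__":
    with open("bert_special_tokens.csv", "w", encoding="utf-8") as f:
        f.write("\n".join(make_userdic_rows()) + "\n")
```

The resulting CSV is then compiled with `mecab-dict-index` as shown above.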
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Use [PyTorch-pretrained-BERT](https://github.com/huggingface/pytorch-pretrained-BERT), which provides an interface to BERT in PyTorch.\n", | |
| "A few modifications are needed to use the BERT model pre-trained on Japanese business news." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:15:54.285906Z", | |
| "start_time": "2019-04-11T01:15:54.278355Z" | |
| }, | |
| "code_folding": [ | |
| 4 | |
| ], | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Provide a Japanese tokenizer usable from BertTokenizer\n", | |
| "DICDIR = \"/path/to/mecab-ipadic-neologd/build/mecab-ipadic-2.7.0-20070801-neologd-20180308\"\n", | |
| "USERDIC = \"/path/to/user.dic\"\n", | |
| "\n", | |
| "class MeCabBert:\n", | |
| " def __init__(self, dicdir=DICDIR, userdic=USERDIC):\n", | |
| " self.tagger = MeCab.Tagger(f\"-d {dicdir} -b 100000 -Owakati -u {userdic}\")\n", | |
| " self.tagger.parse(\"\")\n", | |
| "\n", | |
| " def tokenize(self, text):\n", | |
| " return self.tagger.parse(text).split()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:15:54.345452Z", | |
| "start_time": "2019-04-11T01:15:54.309959Z" | |
| }, | |
| "code_folding": [ | |
| 1 | |
| ], | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Subclass BertTokenizer so that it tokenizes with MeCabBert\n", | |
| "class BertMeCabTokenizer(BertTokenizer):\n", | |
| " def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,\n", | |
| " never_split=(\"[UNK]\", \"[SEP]\", \"[PAD]\", \"[CLS]\", \"[MASK]\")):\n", | |
| " \"\"\"Constructs a BertTokenizer.\n", | |
| " Args:\n", | |
| " vocab_file: Path to a one-wordpiece-per-line vocabulary file\n", | |
| " do_lower_case: Whether to lower case the input\n", | |
| " Only has an effect when do_wordpiece_only=False\n", | |
| " do_basic_tokenize: Whether to do basic tokenization before wordpiece.\n", | |
| " max_len: An artificial maximum length to truncate tokenized sequences to;\n", | |
| " Effective maximum length is always the minimum of this\n", | |
| " value (if specified) and the underlying BERT model's\n", | |
| " sequence length.\n", | |
| " never_split: List of tokens which will never be split during tokenization.\n", | |
| " Only has an effect when do_wordpiece_only=False\n", | |
| " \"\"\"\n", | |
| " if not os.path.isfile(vocab_file):\n", | |
| " raise ValueError(\n", | |
| " \"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained \"\n", | |
| " \"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`\".format(vocab_file))\n", | |
| " self.vocab = load_vocab(vocab_file)\n", | |
| " self.ids_to_tokens = collections.OrderedDict(\n", | |
| " [(ids, tok) for tok, ids in self.vocab.items()])\n", | |
| " self.do_basic_tokenize = do_basic_tokenize\n", | |
| " if do_basic_tokenize:\n", | |
| "            # Changed to MeCabBert\n", | |
| " self.basic_tokenizer = MeCabBert()\n", | |
| " self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)\n", | |
| " self.max_len = max_len if max_len is not None else int(1e12)\n", | |
| "\n", | |
| " def convert_tokens_to_ids(self, tokens):\n", | |
| " \"\"\"Converts a sequence of tokens into ids using the vocab.\"\"\"\n", | |
| " ids = []\n", | |
| " for token in tokens:\n", | |
| "            # SentencePiece/subword splitting is not used, so unknown words map to 1, the id of [UNK]\n", | |
| " ids.append(self.vocab.get(token, 1))\n", | |
| " if len(ids) > self.max_len:\n", | |
| "            logging.warning(\n", | |
| " \"Token indices sequence length is longer than the specified maximum \"\n", | |
| " \" sequence length for this BERT model ({} > {}). Running this\"\n", | |
| " \" sequence through BERT will result in indexing errors\".format(len(ids), self.max_len)\n", | |
| " )\n", | |
| " return ids" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Tokenizer" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:15:55.292660Z", | |
| "start_time": "2019-04-11T01:15:55.214018Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:pytorch_pretrained_bert.tokenization:loading vocabulary file ./PyTorchVer/vocab.txt\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Point the tokenizer at vocab.txt\n", | |
| "tokenizer = BertMeCabTokenizer.from_pretrained('./PyTorchVer/vocab.txt')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:15:55.300598Z", | |
| "start_time": "2019-04-11T01:15:55.294956Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "['[CLS]', '日本', 'で', '有名', 'な', '銀行', 'は', 'どこ', 'です', 'か', '?', '[SEP]', 'それ', 'は', 'みずほ銀行', 'です', '。', '[SEP]']\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "text = \"[CLS] 日本で有名な銀行はどこですか? [SEP] それはみずほ銀行です。 [SEP]\"\n", | |
| "tokenized_text = tokenizer.tokenize(text)\n", | |
| "print(tokenized_text)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:15:55.493018Z", | |
| "start_time": "2019-04-11T01:15:55.488225Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "['[CLS]', '日本', 'で', '有名', 'な', '[MASK]', 'は', 'どこ', 'です', 'か', '?', '[SEP]', 'それ', 'は', 'みずほ銀行', 'です', '。', '[SEP]']\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# For example, to solve a fill-in-the-blank task, replace a token with [MASK]\n", | |
| "masked_index = 5\n", | |
| "tokenized_text[masked_index] = '[MASK]'\n", | |
| "print(tokenized_text)\n", | |
| "\n", | |
| "assert tokenized_text == ['[CLS]', '日本', 'で', '有名', 'な', '[MASK]', 'は', 'どこ', 'です', 'か', '?', '[SEP]', 'それ', 'は', 'みずほ銀行', 'です', '。', '[SEP]']" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:15:55.744021Z", | |
| "start_time": "2019-04-11T01:15:55.741324Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Convert the tokens to a sequence of ids\n", | |
| "indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:15:55.793482Z", | |
| "start_time": "2019-04-11T01:15:55.788153Z" | |
| }, | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "SEP_ID = 3\n", | |
| "def get_segments_ids(indexed_tokens):\n", | |
| " segments_ids = []\n", | |
| " i = 0\n", | |
| " for indexed_token in indexed_tokens:\n", | |
| " segments_ids.append(i)\n", | |
| " if indexed_token == SEP_ID:\n", | |
| " i = i + 1\n", | |
| " return segments_ids" | |
| ] | |
| }, | |
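The segment-id logic above can be checked on a toy id sequence; in this sketch the token ids other than `SEP_ID = 3` are made up for illustration:

```python
SEP_ID = 3  # id of [SEP] in this model's vocab.txt

def get_segments_ids(indexed_tokens):
    # The segment index starts at 0 and increments after each [SEP],
    # so the first sentence gets 0s and the second sentence gets 1s.
    segments_ids = []
    i = 0
    for indexed_token in indexed_tokens:
        segments_ids.append(i)
        if indexed_token == SEP_ID:
            i += 1
    return segments_ids

# [CLS] w w [SEP] w w [SEP]
print(get_segments_ids([2, 10, 11, 3, 12, 13, 3]))  # [0, 0, 0, 0, 1, 1, 1]
```

Note that each `[SEP]` itself is assigned the segment id of the sentence it closes, matching the `[0, ..., 0, 1, ..., 1]` output printed below.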
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:15:55.808830Z", | |
| "start_time": "2019-04-11T01:15:55.806657Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "segments_ids = get_segments_ids(indexed_tokens)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:15:55.859403Z", | |
| "start_time": "2019-04-11T01:15:55.856558Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print(segments_ids)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:15:56.567868Z", | |
| "start_time": "2019-04-11T01:15:56.564388Z" | |
| }, | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "tokens_tensor = torch.tensor([indexed_tokens])\n", | |
| "segments_tensors = torch.tensor([segments_ids])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Load the model and check the number of layers" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:16:05.358798Z", | |
| "start_time": "2019-04-11T01:15:56.853471Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:pytorch_pretrained_bert.modeling:loading archive file ./pytorchver.tar.gz\n", | |
| "INFO:pytorch_pretrained_bert.modeling:extracting archive file ./pytorchver.tar.gz to temp dir /tmp/tmpsnvwexxt\n", | |
| "INFO:pytorch_pretrained_bert.modeling:Model config {\n", | |
| " \"attention_probs_dropout_prob\": 0.1,\n", | |
| " \"hidden_act\": \"gelu\",\n", | |
| " \"hidden_dropout_prob\": 0.1,\n", | |
| " \"hidden_size\": 768,\n", | |
| " \"initializer_range\": 0.02,\n", | |
| " \"intermediate_size\": 3072,\n", | |
| " \"max_position_embeddings\": 512,\n", | |
| " \"num_attention_heads\": 12,\n", | |
| " \"num_hidden_layers\": 12,\n", | |
| " \"type_vocab_size\": 2,\n", | |
| " \"vocab_size\": 32005\n", | |
| "}\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "BertModel(\n", | |
| " (embeddings): BertEmbeddings(\n", | |
| " (word_embeddings): Embedding(32005, 768)\n", | |
| " (position_embeddings): Embedding(512, 768)\n", | |
| " (token_type_embeddings): Embedding(2, 768)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (encoder): BertEncoder(\n", | |
| " (layer): ModuleList(\n", | |
| " (0): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (1): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (2): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (3): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (4): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (5): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (6): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (7): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (8): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (9): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (10): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (11): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " (pooler): BertPooler(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (activation): Tanh()\n", | |
| " )\n", | |
| ")" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model = BertModel.from_pretrained('./pytorchver.tar.gz')\n", | |
| "model.eval()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:16:05.428229Z", | |
| "start_time": "2019-04-11T01:16:05.360688Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "with torch.no_grad():\n", | |
| " encoded_layers, _ = model(tokens_tensor, segments_tensors)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:16:05.592271Z", | |
| "start_time": "2019-04-11T01:16:05.589357Z" | |
| }, | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "assert len(encoded_layers) == 12" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Load the model and solve a fill-in-the-blank task" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:16:14.059586Z", | |
| "start_time": "2019-04-11T01:16:05.671871Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:pytorch_pretrained_bert.modeling:loading archive file ./pytorchver.tar.gz\n", | |
| "INFO:pytorch_pretrained_bert.modeling:extracting archive file ./pytorchver.tar.gz to temp dir /tmp/tmpc5utllq6\n", | |
| "INFO:pytorch_pretrained_bert.modeling:Model config {\n", | |
| " \"attention_probs_dropout_prob\": 0.1,\n", | |
| " \"hidden_act\": \"gelu\",\n", | |
| " \"hidden_dropout_prob\": 0.1,\n", | |
| " \"hidden_size\": 768,\n", | |
| " \"initializer_range\": 0.02,\n", | |
| " \"intermediate_size\": 3072,\n", | |
| " \"max_position_embeddings\": 512,\n", | |
| " \"num_attention_heads\": 12,\n", | |
| " \"num_hidden_layers\": 12,\n", | |
| " \"type_vocab_size\": 2,\n", | |
| " \"vocab_size\": 32005\n", | |
| "}\n", | |
| "\n", | |
| "INFO:pytorch_pretrained_bert.modeling:Weights from pretrained model not used in BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "BertForMaskedLM(\n", | |
| " (bert): BertModel(\n", | |
| " (embeddings): BertEmbeddings(\n", | |
| " (word_embeddings): Embedding(32005, 768)\n", | |
| " (position_embeddings): Embedding(512, 768)\n", | |
| " (token_type_embeddings): Embedding(2, 768)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (encoder): BertEncoder(\n", | |
| " (layer): ModuleList(\n", | |
| " (0): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (1): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (2): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (3): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (4): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (5): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (6): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (7): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (8): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (9): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (10): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (11): BertLayer(\n", | |
| " (attention): BertAttention(\n", | |
| " (self): BertSelfAttention(\n", | |
| " (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " (output): BertSelfOutput(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " (intermediate): BertIntermediate(\n", | |
| " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
| " )\n", | |
| " (output): BertOutput(\n", | |
| " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " (dropout): Dropout(p=0.1)\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " (pooler): BertPooler(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (activation): Tanh()\n", | |
| " )\n", | |
| " )\n", | |
| " (cls): BertOnlyMLMHead(\n", | |
| " (predictions): BertLMPredictionHead(\n", | |
| " (transform): BertPredictionHeadTransform(\n", | |
| " (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
| " (LayerNorm): BertLayerNorm()\n", | |
| " )\n", | |
| " (decoder): Linear(in_features=768, out_features=32005, bias=False)\n", | |
| " )\n", | |
| " )\n", | |
| ")" | |
| ] | |
| }, | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model = BertForMaskedLM.from_pretrained('./pytorchver.tar.gz')\n", | |
| "model.eval()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:16:14.127765Z", | |
| "start_time": "2019-04-11T01:16:14.063117Z" | |
| }, | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "with torch.no_grad():\n", | |
| " predictions = model(tokens_tensor, segments_tensors)\n", | |
| "\n", | |
| "# confirm we were able to predict 'henson'\n", | |
| "predicted_index = torch.argmax(predictions[0, masked_index]).item()\n", | |
| "predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:16:14.133900Z", | |
| "start_time": "2019-04-11T01:16:14.130762Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "銀行\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# 「銀行」をどんぴしゃで当ててくる\n", | |
| "print(predicted_token)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:16:14.154354Z", | |
| "start_time": "2019-04-11T01:16:14.150858Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[UNK]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print(tokenizer.convert_ids_to_tokens([1])[0])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-04-11T01:16:14.180584Z", | |
| "start_time": "2019-04-11T01:16:14.170528Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1 銀行\n", | |
| "2 企業\n", | |
| "3 [UNK]\n", | |
| "4 金融機関\n", | |
| "5 地方銀行\n", | |
| "6 ところ\n", | |
| "7 会社\n", | |
| "8 地銀\n", | |
| "9 国\n", | |
| "10 の\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# 確率上位TOP10の語を出力する\n", | |
| "topn = 10\n", | |
| "for i, idx in enumerate(torch.argsort(predictions[0, masked_index], descending=True)[:topn], start=1):\n", | |
| " print(i, tokenizer.convert_ids_to_tokens([int(idx)])[0])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "hide_input": false, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.3" | |
| }, | |
| "toc": { | |
| "nav_menu": {}, | |
| "number_sections": true, | |
| "sideBar": true, | |
| "skip_h1_title": false, | |
| "toc_cell": false, | |
| "toc_position": {}, | |
| "toc_section_display": "block", | |
| "toc_window_display": false | |
| }, | |
| "varInspector": { | |
| "cols": { | |
| "lenName": 16, | |
| "lenType": 16, | |
| "lenVar": 40 | |
| }, | |
| "kernels_config": { | |
| "python": { | |
| "delete_cmd_postfix": "", | |
| "delete_cmd_prefix": "del ", | |
| "library": "var_list.py", | |
| "varRefreshCmd": "print(var_dic_list())" | |
| }, | |
| "r": { | |
| "delete_cmd_postfix": ") ", | |
| "delete_cmd_prefix": "rm(", | |
| "library": "var_list.r", | |
| "varRefreshCmd": "cat(var_dic_list()) " | |
| } | |
| }, | |
| "types_to_exclude": [ | |
| "module", | |
| "function", | |
| "builtin_function_or_method", | |
| "instance", | |
| "_Feature" | |
| ], | |
| "window_display": false | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |