Last active
March 27, 2024 12:50
-
-
Save BestLemoon/8f1516947f1f7bd8e4f5fc0689316810 to your computer and use it in GitHub Desktop.
My_Handmade_Transformer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/BestLemoon/8f1516947f1f7bd8e4f5fc0689316810/transformer_learning-1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "JM3QHfPzh--b", | |
| "cell_id": "b848c37ce6eb4dae914a16b471722c04", | |
| "deepnote_cell_type": "markdown" | |
| }, | |
| "source": [ | |
| "# Embeddings" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "source_hash": "6bd170bb", | |
| "execution_start": 1711541713717, | |
| "execution_millis": 505, | |
| "deepnote_to_be_reexecuted": false, | |
| "cell_id": "bb423e32ffbb462281fb21daeb7b1a09", | |
| "deepnote_cell_type": "code", | |
| "id": "eZnMSJ753p0p" | |
| }, | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import torch.nn.functional as F\n" | |
| ], | |
| "execution_count": 1, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "sqXjgw_oiBVn", | |
| "source_hash": "426b3b40", | |
| "execution_start": 1711541714229, | |
| "execution_millis": 96, | |
| "deepnote_to_be_reexecuted": false, | |
| "cell_id": "e675994c53664b3095bcf12e02ffd511", | |
| "deepnote_cell_type": "code" | |
| }, | |
| "source": [ | |
class TokenEmbedding(nn.Embedding):
    """Learned lookup table mapping token ids to ``d_model``-dimensional vectors.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: embedding dimension.
        padding_idx: id whose embedding row is initialised to zeros and excluded
            from gradient updates (see ``nn.Embedding``). Defaults to 1, the
            previously hard-coded value. It should match the pad id used to build
            the attention masks; the experiments in this notebook pad with 0, so
            with the default, the real token 1 is frozen at zero instead.
    """

    def __init__(self, vocab_size, d_model, padding_idx=1):
        super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=padding_idx)
| ], | |
| "execution_count": 2, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "Y3obgd9xjlmR", | |
| "source_hash": "7fea523b", | |
| "execution_start": 1711541714237, | |
| "execution_millis": 89, | |
| "deepnote_to_be_reexecuted": false, | |
| "cell_id": "d83a8689d2c74f08b888fef30d2bae3b", | |
| "deepnote_cell_type": "code" | |
| }, | |
| "source": [ | |
class PositionalEmbedding(nn.Module):
    """Fixed (non-trainable) sinusoidal positional encodings.

    Row ``pos`` of the table holds ``sin(pos / 10000**(2i/d_model))`` in the even
    columns and ``cos(pos / 10000**(2i/d_model))`` in the odd columns.

    Args:
        d_model: encoding dimension (odd values are supported).
        max_len: longest sequence length the table covers.
        device: device on which the table is first built. The table is registered
            as a non-persistent buffer, so ``module.to(...)`` moves it afterwards.
    """

    def __init__(self, d_model, max_len, device):
        super(PositionalEmbedding, self).__init__()
        pos = torch.arange(0, max_len, dtype=torch.float, device=device).unsqueeze(1)  # (max_len, 1)
        _2i = torch.arange(0, d_model, step=2, dtype=torch.float, device=device)  # (ceil(d_model/2),)
        angles = pos / (10000 ** (_2i / d_model))  # (max_len, ceil(d_model/2))

        pe = torch.zeros(max_len, d_model, device=device)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # A buffer follows .to()/.cuda() together with the module. persistent=False
        # keeps state_dict() keys identical to the old plain-attribute version.
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x):
        """Return encodings for the first ``x.size(1)`` positions, shape (seq_len, d_model).

        Raises:
            ValueError: if ``x.size(1)`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        return self.pe[:seq_len, :]
| ], | |
| "execution_count": 3, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "hIyDb3Z-uX0X", | |
| "source_hash": "dd57df35", | |
| "execution_start": 1711541714258, | |
| "execution_millis": 101, | |
| "deepnote_to_be_reexecuted": false, | |
| "cell_id": "f34c0740aebb4050a5d97ae600ef0e91", | |
| "deepnote_cell_type": "code" | |
| }, | |
| "source": [ | |
class TotalEmbedding(nn.Module):
    """Token embedding plus positional encoding, followed by dropout.

    Args:
        d_model: embedding dimension.
        max_len: maximum sequence length supported by the positional table.
        vocab_size: size of the token vocabulary.
        device: device passed to the positional-encoding table.
        drop_prob: dropout probability applied to the summed embedding.
    """

    def __init__(self, d_model, max_len, vocab_size, device, drop_prob):
        super(TotalEmbedding, self).__init__()
        self.TokenEmbedding = TokenEmbedding(vocab_size, d_model)
        self.PositionalEmbedding = PositionalEmbedding(d_model, max_len, device)
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Embed token ids ``x`` (batch_size, seq_len) into (batch_size, seq_len, d_model)."""
        # The (seq_len, d_model) positional slice broadcasts over the batch axis.
        summed = self.TokenEmbedding(x) + self.PositionalEmbedding(x)
        return self.dropout(summed)
| ], | |
| "execution_count": 4, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "nlnS6PqJBADg", | |
| "cell_id": "a8a61941bdb4420c8fd66f84e65ad44d", | |
| "deepnote_cell_type": "markdown" | |
| }, | |
| "source": [ | |
| "# LayerNorm" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "jaN9Oauvx5MG", | |
| "source_hash": "2e7186ce", | |
| "execution_start": 1711541714259, | |
| "execution_millis": 101, | |
| "deepnote_to_be_reexecuted": false, | |
| "cell_id": "0be27f0e5ce341f19e49ee91163a7cb6", | |
| "deepnote_cell_type": "code" | |
| }, | |
| "source": [ | |
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and shift.

    Args:
        d_model: size of the normalised (last) dimension.
        eps: added to the variance before the square root for numerical stability.

    The scale parameter is named ``gemma`` (sic). The name is kept so existing
    state_dict keys still load.
    """

    def __init__(self, d_model, eps=1e-10):
        super(LayerNorm, self).__init__()
        self.gemma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        # keepdim=True so the statistics broadcast back against x.
        mu = x.mean(-1, keepdim=True)
        # Biased (population) variance, as in the standard LayerNorm definition.
        sigma2 = x.var(-1, keepdim=True, unbiased=False)
        normalized = (x - mu) / torch.sqrt(sigma2 + self.eps)
        return self.gemma * normalized + self.beta
| ], | |
| "execution_count": 5, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "ogKEyTefBBmJ", | |
| "cell_id": "7b70e0a3e0e845babb19631b73d1fb8e", | |
| "deepnote_cell_type": "markdown" | |
| }, | |
| "source": [ | |
| "# FFN" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "nWZjqG_b_tTe", | |
| "source_hash": "b1721ee9", | |
| "execution_start": 1711541714260, | |
| "execution_millis": 100, | |
| "deepnote_to_be_reexecuted": false, | |
| "cell_id": "12596bebdd504d369bb5648ddc32bfe0", | |
| "deepnote_cell_type": "code" | |
| }, | |
| "source": [ | |
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input and output feature size.
        hidden_size: width of the inner layer.
        drop_out: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, hidden_size, drop_out=0.1):
        super(FFN, self).__init__()
        self.fc1 = nn.Linear(d_model, hidden_size)
        self.fc2 = nn.Linear(hidden_size, d_model)
        self.dropout = nn.Dropout(drop_out)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)
| ], | |
| "execution_count": 6, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "CE-wnzEwBDix", | |
| "cell_id": "0bfee901eaf148ea8e695164b7ce56e6", | |
| "deepnote_cell_type": "markdown" | |
| }, | |
| "source": [ | |
| "# Multi-Head Attention" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "hQmszCuyQAnH", | |
| "source_hash": "c3c12d3b", | |
| "execution_start": 1711541714263, | |
| "execution_millis": 120, | |
| "deepnote_to_be_reexecuted": false, | |
| "cell_id": "6a46cc717d9049f5b41ea37d263c3859", | |
| "deepnote_cell_type": "code" | |
| }, | |
| "source": [ | |
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention with ``n_head`` parallel heads.

    Args:
        d_model: model width; must be divisible by ``n_head``.
        n_head: number of attention heads.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        if d_model % n_head != 0:
            # Otherwise view(..., n_head, d_k) in forward() either fails with an
            # obscure shape error or silently regroups features across positions.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )
        self.d_model = d_model
        self.n_head = n_head
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.O = nn.Linear(d_model, d_model)
        self.d_k = d_model // n_head

    def forward(self, q, k, v, mask):
        """Attend from queries ``q`` to keys ``k`` and values ``v``.

        Args:
            q: (batch_size, q_len, d_model) queries.
            k: (batch_size, kv_len, d_model) keys.
            v: (batch_size, kv_len, d_model) values.
            mask: None, or a tensor broadcastable to
                (batch_size, n_head, q_len, kv_len). Positions where it equals 0
                are excluded from attention.

        Returns:
            (batch_size, q_len, d_model) tensor.
        """
        batch_size = q.size(0)

        def split_heads(t):
            # (batch_size, len, d_model) -> (batch_size, n_head, len, d_k)
            return t.view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)

        Q = split_heads(self.W_q(q))
        K = split_heads(self.W_k(k))
        V = split_heads(self.W_v(v))

        # (batch_size, n_head, q_len, kv_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # Use the most negative finite value rather than -inf. A query row whose
            # keys are all masked then gets uniform weights instead of NaN (softmax
            # of an all -inf row is NaN, which would spread into outputs and
            # gradients). Partially masked rows are unaffected, because exp() of
            # such a large negative number underflows to exactly 0.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)

        # (batch_size, n_head, q_len, d_k) -> (batch_size, q_len, d_model)
        context = torch.matmul(attn, V).transpose(1, 2).reshape(batch_size, -1, self.d_model)
        return self.O(context)
| ], | |
| "execution_count": 7, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "ZTJ225quBWt2", | |
| "cell_id": "f560551557bd457ba4b8abedae2e6c76", | |
| "deepnote_cell_type": "markdown" | |
| }, | |
| "source": [ | |
| "# Encoder Layer" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "wIzvSMrlA4QF", | |
| "source_hash": "2c0b555e", | |
| "execution_start": 1711541714285, | |
| "execution_millis": 99, | |
| "deepnote_to_be_reexecuted": false, | |
| "cell_id": "acafabedd15e4e8aba25213b2ba46a67", | |
| "deepnote_cell_type": "code" | |
| }, | |
| "source": [ | |
class EncoderLayer(nn.Module):
    """One post-norm encoder layer: self-attention, then a feed-forward network.

    Each sub-layer output goes through dropout, a residual connection and LayerNorm.

    Args:
        d_model: model width.
        hidden_size: inner width of the feed-forward network.
        n_head: number of attention heads.
        drop_prob: dropout probability.
    """

    def __init__(self, d_model, hidden_size, n_head, drop_prob):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.ffn = FFN(d_model, hidden_size, drop_prob)
        self.dropout1 = nn.Dropout(drop_prob)
        self.dropout2 = nn.Dropout(drop_prob)

    def forward(self, x, mask=None):
        """Apply the layer to ``x``; ``mask`` is forwarded unchanged to self-attention."""
        attended = self.dropout1(self.attention(x, x, x, mask))
        x = self.norm1(attended + x)
        transformed = self.dropout2(self.ffn(x))
        return self.norm2(transformed + x)
| ], | |
| "execution_count": 8, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "formattedRanges": [], | |
| "cell_id": "eb565d3ed15b4ac98414d79848e4ed65", | |
| "deepnote_cell_type": "text-cell-h1", | |
| "id": "j-uENHYV3p0t" | |
| }, | |
| "source": [ | |
| "# Encoder" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "source_hash": "c97d5e14", | |
| "execution_start": 1711541714286, | |
| "execution_millis": 98, | |
| "deepnote_to_be_reexecuted": false, | |
| "cell_id": "201536382100453e85c5a0b7ba449a1b", | |
| "deepnote_cell_type": "code", | |
| "id": "_pcFa3LD3p0t" | |
| }, | |
| "source": [ | |
class Encoder(nn.Module):
    """Source embedding followed by a stack of ``layer_num`` encoder layers."""

    def __init__(self, max_len, vocab_size, device, d_model, hidden_size, n_head, drop_prob, layer_num):
        super(Encoder, self).__init__()
        self.embedding = TotalEmbedding(d_model, max_len, vocab_size, device, drop_prob)
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, hidden_size, n_head, drop_prob) for _ in range(layer_num)
        )

    def forward(self, x, padding_mask):
        """Encode source token ids ``x``; ``padding_mask`` is passed to every layer."""
        hidden = self.embedding(x)
        for layer in self.layers:
            hidden = layer(hidden, padding_mask)
        return hidden
| ], | |
| "execution_count": 9, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "cell_id": "c557cf875f044069af27c1169f8fb4ec", | |
| "deepnote_cell_type": "markdown", | |
| "id": "ZRDNyLRv3p0u" | |
| }, | |
| "source": [ | |
| "# Decoder Layer" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "source_hash": "27693686", | |
| "execution_start": 1711541714286, | |
| "execution_millis": 99, | |
| "deepnote_to_be_reexecuted": false, | |
| "cell_id": "91e8bf93454646258229b2a1062f855d", | |
| "deepnote_cell_type": "code", | |
| "id": "uA3I-14k3p0u" | |
| }, | |
| "source": [ | |
class DecoderLayer(nn.Module):
    """One post-norm decoder layer.

    It has three sub-layers, each followed by dropout, a residual connection and
    LayerNorm:
    1. masked self-attention over the decoder states;
    2. cross-attention from those states to the encoder output;
    3. a position-wise feed-forward network.

    Args:
        d_model: model width.
        hidden_size: inner width of the feed-forward network.
        n_head: number of attention heads.
        drop_prob: dropout probability.
    """

    def __init__(self, d_model, hidden_size, n_head, drop_prob):
        super(DecoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model, n_head)
        self.cross_attention = MultiHeadAttention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.ffn = FFN(d_model, hidden_size, drop_prob)
        self.drop_out1 = nn.Dropout(drop_prob)
        self.drop_out2 = nn.Dropout(drop_prob)
        self.drop_out3 = nn.Dropout(drop_prob)

    def forward(self, enc, dec, src_mask, trg_mask):
        """Apply the layer.

        Args:
            enc: encoder output, used as keys and values of the cross-attention.
            dec: decoder hidden states entering this layer.
            src_mask: mask over encoder positions, used by the cross-attention.
            trg_mask: mask for the decoder self-attention.
        """
        _x = dec
        x = self.attention(dec, dec, dec, trg_mask)
        x = self.drop_out1(x)
        x = self.norm1(x + _x)

        _x = x
        # Queries come from the self-attention sub-layer output. Using the raw layer
        # input `dec` here would discard what the masked self-attention computed,
        # except through the residual path.
        x = self.cross_attention(x, enc, enc, src_mask)
        x = self.drop_out2(x)
        x = self.norm2(x + _x)

        _x = x
        x = self.ffn(x)
        x = self.drop_out3(x)
        x = self.norm3(x + _x)
        return x
| ], | |
| "execution_count": 10, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "formattedRanges": [], | |
| "cell_id": "972dc805b64c49509134d885e78a68c5", | |
| "deepnote_cell_type": "text-cell-h1", | |
| "id": "mDDoIh8X3p0u" | |
| }, | |
| "source": [ | |
| "# Decoder" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "source_hash": "28e11345", | |
| "execution_start": 1711541714334, | |
| "execution_millis": 268, | |
| "deepnote_to_be_reexecuted": false, | |
| "cell_id": "0d4c98a6ef62467aa88ecd7bb0bb6c6b", | |
| "deepnote_cell_type": "code", | |
| "id": "zoA-gOqJ3p0u" | |
| }, | |
| "source": [ | |
class Decoder(nn.Module):
    """Target embedding, a stack of decoder layers, and a vocabulary projection."""

    def __init__(self, max_len, vocab_size, device, d_model, hidden_size, n_head, drop_prob, layer_num):
        super(Decoder, self).__init__()
        self.embedding = TotalEmbedding(d_model, max_len, vocab_size, device, drop_prob)
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, hidden_size, n_head, drop_prob) for _ in range(layer_num)]
        )
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, enc, dec, src_mask, trg_mask):
        """Decode target token ids ``dec`` against the encoder output ``enc``.

        Returns:
            Unnormalised logits of shape (batch_size, trg_len, vocab_size). Apply
            ``softmax(dim=-1)`` if probabilities are needed.
        """
        dec = self.embedding(dec)
        for layer in self.layers:
            dec = layer(enc, dec, src_mask, trg_mask)
        # Return raw logits. nn.CrossEntropyLoss applies log_softmax itself, so an
        # extra softmax here squashes the logits into [0, 1], the gradients nearly
        # vanish, and the loss stays near ln(vocab_size).
        return self.linear(dec)
| ], | |
| "execution_count": 11, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "formattedRanges": [], | |
| "cell_id": "441bc8472a164ed1978a6edce3861289", | |
| "deepnote_cell_type": "text-cell-h1", | |
| "id": "AdWIWuvy3p0u" | |
| }, | |
| "source": [ | |
| "# Transformer" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "source_hash": "92b4cdc8", | |
| "execution_start": 1711541889626, | |
| "execution_millis": 33, | |
| "deepnote_to_be_reexecuted": false, | |
| "cell_id": "d1571b0984fb4fe8962f9b70742c85d0", | |
| "deepnote_cell_type": "code", | |
| "id": "qwrMo5CS3p0u" | |
| }, | |
| "source": [ | |
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from the Encoder and Decoder above.

    Args:
        src_pad_idx: padding id in source sequences (masked out of attention).
        trg_pad_idx: padding id in target sequences (masked out of attention).
        max_len: maximum sequence length for the positional tables.
        enc_vocab_size: source vocabulary size.
        dec_vocab_size: target vocabulary size.
        device: device used to build the positional tables; stored as ``self.device``.
        d_model: model width.
        hidden_size: inner width of the feed-forward networks.
        n_head: number of attention heads.
        drop_prob: dropout probability.
        layer_num: number of encoder layers and of decoder layers.
    """

    def __init__(self, src_pad_idx, trg_pad_idx, max_len, enc_vocab_size, dec_vocab_size, device, d_model, hidden_size, n_head, drop_prob, layer_num):
        super(Transformer, self).__init__()
        self.encoder = Encoder(max_len, enc_vocab_size, device, d_model, hidden_size, n_head, drop_prob, layer_num)
        self.decoder = Decoder(max_len, dec_vocab_size, device, d_model, hidden_size, n_head, drop_prob, layer_num)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def forward(self, src, trg):
        """Run the encoder on ``src`` ids and the decoder on ``trg`` ids.

        Returns whatever the decoder returns for the decoder input ``trg``.
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc = self.encoder(src, src_mask)
        return self.decoder(enc, trg, src_mask, trg_mask)

    def make_src_mask(self, src):
        """Key-padding mask of shape (batch, 1, 1, src_len); True marks non-pad tokens."""
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """Causal plus key-padding mask of shape (batch, 1, trg_len, trg_len).

        Entry [b, 0, i, j] is True when query position i may attend to key j,
        i.e. j <= i and trg[b, j] is not padding.
        """
        trg_len = trg.size(1)
        # Mask padded *keys* (broadcast over queries), like make_src_mask. Masking
        # padded *queries* instead (unsqueeze(3)) leaves those rows with no visible
        # key at all, which yields NaN when masked scores are filled with -inf.
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, trg_len)
        # Build on trg's device so the mask follows the inputs even if the model
        # was moved after construction.
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()  # (trg_len, trg_len)
        return trg_pad_mask & trg_sub_mask
| ], | |
| "execution_count": 12, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "formattedRanges": [], | |
| "cell_id": "3457c47436f8475fa5f3581c188bf041", | |
| "deepnote_cell_type": "text-cell-h1", | |
| "id": "k6qS8saW3p0u" | |
| }, | |
| "source": [ | |
| "# Experiment" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "source_hash": "bea0b6b2", | |
| "execution_start": 1711541895993, | |
| "execution_millis": 674, | |
| "deepnote_to_be_reexecuted": false, | |
| "cell_id": "f5664ef823ea43a89f3b0cf4fe187b33", | |
| "deepnote_cell_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "gKPW5rjg3p0v", | |
| "outputId": "6e99a32b-02e9-4a88-b103-53f62ab0f9b0" | |
| }, | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "# from your_transformer_module import Transformer # 确保引入了你的 Transformer 类\n", | |
| "\n", | |
# 1. Prepare data: random token ids drawn from [1, vocab) so that id 0 (padding) never appears.
torch.manual_seed(0)  # reproducible smoke test
batch_size = 32
src_seq_len = 10
trg_seq_len = 15
src_pad_idx = 0
trg_pad_idx = 0
max_len = 20
enc_vocab_size = 1000
dec_vocab_size = 1000
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
d_model = 512
hidden_size = 2048
n_head = 8
drop_prob = 0.1
layer_num = 6

src = torch.randint(1, enc_vocab_size, (batch_size, src_seq_len)).to(device)
trg = torch.randint(1, dec_vocab_size, (batch_size, trg_seq_len)).to(device)

# 2. Build the model
model = Transformer(src_pad_idx, trg_pad_idx, max_len, enc_vocab_size, dec_vocab_size, device, d_model, hidden_size, n_head, drop_prob, layer_num).to(device)

# 3. Forward-pass shape check. The decoder input drops trg's last token (teacher forcing).
# eval() disables dropout; no_grad() avoids building an autograd graph for a pure check.
# Failing checks raise, so "Run All" stops here instead of printing and carrying on.
model.eval()
with torch.no_grad():
    output = model(src, trg[:, :-1])
assert output.shape == (batch_size, trg_seq_len - 1, dec_vocab_size), "输出形状不匹配"
print("前向传播测试通过!")

# 4. Mask shape checks
src_mask = model.make_src_mask(src)
trg_mask = model.make_trg_mask(trg[:, :-1])
assert src_mask.shape == (batch_size, 1, 1, src_seq_len), "源掩码形状不正确"
assert trg_mask.shape == (batch_size, 1, trg_seq_len - 1, trg_seq_len - 1), "目标掩码形状不正确"
print("掩码生成测试通过!")
| ], | |
| "execution_count": 13, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "前向传播测试通过!\n", | |
| "掩码生成测试通过!\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "source_hash": "b623e53d", | |
| "deepnote_to_be_reexecuted": true, | |
| "cell_id": "1d1a09a9e1a04e31b2a39298d6faad08", | |
| "deepnote_cell_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "khDgNI_93p0v", | |
| "outputId": "07c08084-06da-4ecc-8e23-4bd61a267539" | |
| }, | |
| "source": [ | |
| "import torch\n", | |
| "from torch.utils.data import DataLoader, TensorDataset\n", | |
| "import numpy as np\n", | |
| "\n", | |
# Toy hyper-parameters for a copy task
src_pad_idx = 0
trg_pad_idx = 0
max_len = 10
enc_vocab_size = 20  # toy vocabulary size
dec_vocab_size = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
d_model = 512
hidden_size = 2048
n_head = 8
drop_prob = 0.1
layer_num = 6
num_epochs = 100

torch.manual_seed(0)  # reproducible initialisation and data

# Build the model
model = Transformer(src_pad_idx, trg_pad_idx, max_len, enc_vocab_size, dec_vocab_size, device, d_model, hidden_size, n_head, drop_prob, layer_num)
model.to(device)


def generate_data(batch_size, seq_len, vocab_size):
    """Return a (batch_size, seq_len) tensor of ids drawn from [1, vocab_size).

    Id 0 is never produced, because it is the pad id ignored by the loss and masks.
    """
    return torch.randint(1, vocab_size, (batch_size, seq_len))


batch_size = 32
seq_len = 10

# Copy task: the target sequence is identical to the source sequence.
src_data = generate_data(batch_size, seq_len, enc_vocab_size)
trg_data = src_data.clone()

dataset = TensorDataset(src_data, trg_data)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss(ignore_index=trg_pad_idx)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for src, trg in dataloader:
        src, trg = src.to(device), trg.to(device)
        trg_in = trg[:, :-1]   # decoder input: drop the last token
        trg_out = trg[:, 1:]   # prediction target: drop the first token (shifted by one)
        optimizer.zero_grad()
        output = model(src, trg_in)
        loss = criterion(output.reshape(-1, output.shape[-1]), trg_out.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch}, Loss: {total_loss / len(dataloader)}")

# Expected: the loss decreases over the epochs as the model learns the copy task.
| ], | |
| "execution_count": 17, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Epoch 0, Loss: 2.995344877243042\n", | |
| "Epoch 1, Loss: 2.984524726867676\n", | |
| "Epoch 2, Loss: 2.9938313961029053\n", | |
| "Epoch 3, Loss: 2.9807004928588867\n", | |
| "Epoch 4, Loss: 2.9887120723724365\n", | |
| "Epoch 5, Loss: 2.998847246170044\n", | |
| "Epoch 6, Loss: 2.9817023277282715\n", | |
| "Epoch 7, Loss: 2.9803860187530518\n", | |
| "Epoch 8, Loss: 2.9835989475250244\n", | |
| "Epoch 9, Loss: 2.9836642742156982\n", | |
| "Epoch 10, Loss: 2.9835991859436035\n", | |
| "Epoch 11, Loss: 2.9806711673736572\n", | |
| "Epoch 12, Loss: 2.9795029163360596\n", | |
| "Epoch 13, Loss: 2.982041358947754\n", | |
| "Epoch 14, Loss: 2.980454444885254\n", | |
| "Epoch 15, Loss: 2.981685161590576\n", | |
| "Epoch 16, Loss: 2.980349540710449\n", | |
| "Epoch 17, Loss: 2.980302572250366\n", | |
| "Epoch 18, Loss: 2.981624126434326\n", | |
| "Epoch 19, Loss: 2.979287624359131\n", | |
| "Epoch 20, Loss: 2.9816741943359375\n", | |
| "Epoch 21, Loss: 2.9805760383605957\n", | |
| "Epoch 22, Loss: 2.9801182746887207\n", | |
| "Epoch 23, Loss: 2.9790592193603516\n", | |
| "Epoch 24, Loss: 2.9804089069366455\n", | |
| "Epoch 25, Loss: 2.980847120285034\n", | |
| "Epoch 26, Loss: 2.97930645942688\n", | |
| "Epoch 27, Loss: 2.9797005653381348\n", | |
| "Epoch 28, Loss: 2.9782049655914307\n", | |
| "Epoch 29, Loss: 2.979945182800293\n", | |
| "Epoch 30, Loss: 2.977987051010132\n", | |
| "Epoch 31, Loss: 2.9785232543945312\n", | |
| "Epoch 32, Loss: 2.9788577556610107\n", | |
| "Epoch 33, Loss: 2.9791769981384277\n", | |
| "Epoch 34, Loss: 2.978320360183716\n", | |
| "Epoch 35, Loss: 2.9786667823791504\n", | |
| "Epoch 36, Loss: 2.978853464126587\n", | |
| "Epoch 37, Loss: 2.978756904602051\n", | |
| "Epoch 38, Loss: 2.9776389598846436\n", | |
| "Epoch 39, Loss: 2.9791746139526367\n", | |
| "Epoch 40, Loss: 2.978610038757324\n", | |
| "Epoch 41, Loss: 2.980536460876465\n", | |
| "Epoch 42, Loss: 2.979384422302246\n", | |
| "Epoch 43, Loss: 2.9789557456970215\n", | |
| "Epoch 44, Loss: 2.978343963623047\n", | |
| "Epoch 45, Loss: 2.9789674282073975\n", | |
| "Epoch 46, Loss: 2.9782252311706543\n", | |
| "Epoch 47, Loss: 2.9781506061553955\n", | |
| "Epoch 48, Loss: 2.9780097007751465\n", | |
| "Epoch 49, Loss: 2.9773736000061035\n", | |
| "Epoch 50, Loss: 2.9779841899871826\n", | |
| "Epoch 51, Loss: 2.9753427505493164\n", | |
| "Epoch 52, Loss: 2.979438543319702\n", | |
| "Epoch 53, Loss: 2.9791324138641357\n", | |
| "Epoch 54, Loss: 2.9781386852264404\n", | |
| "Epoch 55, Loss: 2.9788925647735596\n", | |
| "Epoch 56, Loss: 2.976363182067871\n", | |
| "Epoch 57, Loss: 2.9784836769104004\n", | |
| "Epoch 58, Loss: 2.9774186611175537\n", | |
| "Epoch 59, Loss: 2.976188898086548\n", | |
| "Epoch 60, Loss: 2.9777863025665283\n", | |
| "Epoch 61, Loss: 2.9790074825286865\n", | |
| "Epoch 62, Loss: 2.9771623611450195\n", | |
| "Epoch 63, Loss: 2.978694438934326\n", | |
| "Epoch 64, Loss: 2.977710723876953\n", | |
| "Epoch 65, Loss: 2.977355718612671\n", | |
| "Epoch 66, Loss: 2.9788594245910645\n", | |
| "Epoch 67, Loss: 2.9797263145446777\n", | |
| "Epoch 68, Loss: 2.9768307209014893\n", | |
| "Epoch 69, Loss: 2.976647138595581\n", | |
| "Epoch 70, Loss: 2.979998826980591\n", | |
| "Epoch 71, Loss: 2.9781081676483154\n", | |
| "Epoch 72, Loss: 2.9787964820861816\n", | |
| "Epoch 73, Loss: 2.9788260459899902\n", | |
| "Epoch 74, Loss: 2.977302074432373\n", | |
| "Epoch 75, Loss: 2.9781789779663086\n", | |
| "Epoch 76, Loss: 2.9773879051208496\n", | |
| "Epoch 77, Loss: 2.97847318649292\n", | |
| "Epoch 78, Loss: 2.9774057865142822\n", | |
| "Epoch 79, Loss: 2.9770147800445557\n", | |
| "Epoch 80, Loss: 2.9789209365844727\n", | |
| "Epoch 81, Loss: 2.977285861968994\n", | |
| "Epoch 82, Loss: 2.9775390625\n", | |
| "Epoch 83, Loss: 2.979658365249634\n", | |
| "Epoch 84, Loss: 2.9776415824890137\n", | |
| "Epoch 85, Loss: 2.979929208755493\n", | |
| "Epoch 86, Loss: 2.979032278060913\n", | |
| "Epoch 87, Loss: 2.9790756702423096\n", | |
| "Epoch 88, Loss: 2.9792330265045166\n", | |
| "Epoch 89, Loss: 2.9780335426330566\n", | |
| "Epoch 90, Loss: 2.9781672954559326\n", | |
| "Epoch 91, Loss: 2.976395845413208\n", | |
| "Epoch 92, Loss: 2.9773950576782227\n", | |
| "Epoch 93, Loss: 2.9803178310394287\n", | |
| "Epoch 94, Loss: 2.9772472381591797\n", | |
| "Epoch 95, Loss: 2.9769339561462402\n", | |
| "Epoch 96, Loss: 2.980050802230835\n", | |
| "Epoch 97, Loss: 2.9778597354888916\n", | |
| "Epoch 98, Loss: 2.977893829345703\n", | |
| "Epoch 99, Loss: 2.977822780609131\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=7ee49933-c18c-41ac-a77e-19b1d981850e' target=\"_blank\">\n", | |
| "<img alt='Created in deepnote.com' style='display:inline;max-height:16px;margin:0px;margin-right:7.5px;' src='data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiPz4KPHN2ZyB3aWR0aD0iODBweCIgaGVpZ2h0PSI4MHB4IiB2aWV3Qm94PSIwIDAgODAgODAiIHZlcnNpb249IjEuMSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIiB4bWxuczp4bGluaz0iaHR0cDovL3d3dy53My5vcmcvMTk5OS94bGluayI+CiAgICA8IS0tIEdlbmVyYXRvcjogU2tldGNoIDU0LjEgKDc2NDkwKSAtIGh0dHBzOi8vc2tldGNoYXBwLmNvbSAtLT4KICAgIDx0aXRsZT5Hcm91cCAzPC90aXRsZT4KICAgIDxkZXNjPkNyZWF0ZWQgd2l0aCBTa2V0Y2guPC9kZXNjPgogICAgPGcgaWQ9IkxhbmRpbmciIHN0cm9rZT0ibm9uZSIgc3Ryb2tlLXdpZHRoPSIxIiBmaWxsPSJub25lIiBmaWxsLXJ1bGU9ImV2ZW5vZGQiPgogICAgICAgIDxnIGlkPSJBcnRib2FyZCIgdHJhbnNmb3JtPSJ0cmFuc2xhdGUoLTEyMzUuMDAwMDAwLCAtNzkuMDAwMDAwKSI+CiAgICAgICAgICAgIDxnIGlkPSJHcm91cC0zIiB0cmFuc2Zvcm09InRyYW5zbGF0ZSgxMjM1LjAwMDAwMCwgNzkuMDAwMDAwKSI+CiAgICAgICAgICAgICAgICA8cG9seWdvbiBpZD0iUGF0aC0yMCIgZmlsbD0iIzAyNjVCNCIgcG9pbnRzPSIyLjM3NjIzNzYyIDgwIDM4LjA0NzY2NjcgODAgNTcuODIxNzgyMiA3My44MDU3NTkyIDU3LjgyMTc4MjIgMzIuNzU5MjczOSAzOS4xNDAyMjc4IDMxLjY4MzE2ODMiPjwvcG9seWdvbj4KICAgICAgICAgICAgICAgIDxwYXRoIGQ9Ik0zNS4wMDc3MTgsODAgQzQyLjkwNjIwMDcsNzYuNDU0OTM1OCA0Ny41NjQ5MTY3LDcxLjU0MjI2NzEgNDguOTgzODY2LDY1LjI2MTk5MzkgQzUxLjExMjI4OTksNTUuODQxNTg0MiA0MS42NzcxNzk1LDQ5LjIxMjIyODQgMjUuNjIzOTg0Niw0OS4yMTIyMjg0IEMyNS40ODQ5Mjg5LDQ5LjEyNjg0NDggMjkuODI2MTI5Niw0My4yODM4MjQ4IDM4LjY0NzU4NjksMzEuNjgzMTY4MyBMNzIuODcxMjg3MSwzMi41NTQ0MjUgTDY1LjI4MDk3Myw2Ny42NzYzNDIxIEw1MS4xMTIyODk5LDc3LjM3NjE0NCBMMzUuMDA3NzE4LDgwIFoiIGlkPSJQYXRoLTIyIiBmaWxsPSIjMDAyODY4Ij48L3BhdGg+CiAgICAgICAgICAgICAgICA8cGF0aCBkPSJNMCwzNy43MzA0NDA1IEwyNy4xMTQ1MzcsMC4yNTcxMTE0MzYgQzYyLjM3MTUxMjMsLTEuOTkwNzE3MDEgODAsMTAuNTAwMzkyNyA4MCwzNy43MzA0NDA1IEM4MCw2NC45NjA0ODgyIDY0Ljc3NjUwMzgsNzkuMDUwMzQxNCAzNC4zMjk1MTEzLDgwIEM0Ny4wNTUzNDg5LDc3LjU2NzA4MDggNTMuNDE4MjY3Nyw3MC4zMTM2MTAzIDUzLjQxODI2NzcsNTguMjM5NTg4NSBDNTMuNDE4MjY3Nyw0MC4xMjg1NTU3IDM2LjMwMzk1NDQsMzcuNzMwNDQwNSAyNS4yMjc0MTcsMzcuNzMwNDQwNSBDMTcuODQzMDU4NiwzNy43MzA0NDA1I
DkuNDMzOTE5NjYsMzcuNzMwNDQwNSAwLDM3LjczMDQ0MDUgWiIgaWQ9IlBhdGgtMTkiIGZpbGw9IiMzNzkzRUYiPjwvcGF0aD4KICAgICAgICAgICAgPC9nPgogICAgICAgIDwvZz4KICAgIDwvZz4KPC9zdmc+' > </img>\n", | |
| "Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>" | |
| ], | |
| "metadata": { | |
| "created_in_deepnote_cell": true, | |
| "deepnote_cell_type": "markdown", | |
| "id": "lZ_LDbtZ3p0v" | |
| } | |
| } | |
| ], | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "deepnote_persisted_session": { | |
| "createdAt": "2024-03-27T08:00:50.222Z" | |
| }, | |
| "deepnote_full_width": true, | |
| "deepnote_notebook_id": "15281b94ce5b48a38d2ff6c59c47cf1c", | |
| "deepnote_execution_queue": [], | |
| "colab": { | |
| "provenance": [], | |
| "gpuType": "T4", | |
| "include_colab_link": true | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "accelerator": "GPU" | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment