Skip to content

Instantly share code, notes, and snippets.

@BestLemoon
Last active March 27, 2024 12:50
Show Gist options
  • Select an option

  • Save BestLemoon/8f1516947f1f7bd8e4f5fc0689316810 to your computer and use it in GitHub Desktop.

Select an option

Save BestLemoon/8f1516947f1f7bd8e4f5fc0689316810 to your computer and use it in GitHub Desktop.
My_Handmade_Transformer
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/BestLemoon/8f1516947f1f7bd8e4f5fc0689316810/transformer_learning-1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JM3QHfPzh--b",
"cell_id": "b848c37ce6eb4dae914a16b471722c04",
"deepnote_cell_type": "markdown"
},
"source": [
"# Embeddings"
]
},
{
"cell_type": "code",
"metadata": {
"source_hash": "6bd170bb",
"execution_start": 1711541713717,
"execution_millis": 505,
"deepnote_to_be_reexecuted": false,
"cell_id": "bb423e32ffbb462281fb21daeb7b1a09",
"deepnote_cell_type": "code",
"id": "eZnMSJ753p0p"
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sqXjgw_oiBVn",
"source_hash": "426b3b40",
"execution_start": 1711541714229,
"execution_millis": 96,
"deepnote_to_be_reexecuted": false,
"cell_id": "e675994c53664b3095bcf12e02ffd511",
"deepnote_cell_type": "code"
},
"source": [
"class TokenEmbedding(nn.Embedding):\n",
" def __init__(self, vocab_size, d_model):\n",
" super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=1)\n"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Y3obgd9xjlmR",
"source_hash": "7fea523b",
"execution_start": 1711541714237,
"execution_millis": 89,
"deepnote_to_be_reexecuted": false,
"cell_id": "d83a8689d2c74f08b888fef30d2bae3b",
"deepnote_cell_type": "code"
},
"source": [
"class PositionalEmbedding(nn.Module):\n",
" def __init__(self,d_model,max_len,device):\n",
" super(PositionalEmbedding,self).__init__()\n",
" self.pe=torch.zeros(max_len,d_model,device=device) # (max_len,d_model)\n",
" self.pe.requires_grad_(False) # no grad bc no change to pe during transformer\n",
" pos = torch.arange(0, max_len, dtype=torch.float, device=device).unsqueeze(1) # (max_len, 1)\n",
" _2i = torch.arange(0, d_model, step=2, dtype=torch.float, device=device) # (d_model/2,)\n",
"\n",
" self.pe[:,0::2]=torch.sin(pos / (10000 ** (_2i / d_model))) # (max_len,d_model/2)\n",
" self.pe[:,1::2]=torch.cos(pos / (10000 ** (_2i / d_model))) # (max_len,d_model/2)\n",
" def forward(self,x):\n",
" seq_len=x.size(1)\n",
" return self.pe[:seq_len,:]"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "hIyDb3Z-uX0X",
"source_hash": "dd57df35",
"execution_start": 1711541714258,
"execution_millis": 101,
"deepnote_to_be_reexecuted": false,
"cell_id": "f34c0740aebb4050a5d97ae600ef0e91",
"deepnote_cell_type": "code"
},
"source": [
"class TotalEmbedding(nn.Module):\n",
" def __init__(self, d_model, max_len, vocab_size, device, drop_prob):\n",
" super(TotalEmbedding, self).__init__()\n",
" self.TokenEmbedding = TokenEmbedding(vocab_size, d_model)\n",
" self.PositionalEmbedding = PositionalEmbedding(d_model, max_len, device)\n",
" self.dropout = nn.Dropout(p=drop_prob)\n",
"\n",
" def forward(self, x):\n",
" # x:(batch_size, seq_len)\n",
" token = self.TokenEmbedding(x) # (batch_size, seq_len, d_model)\n",
" pe = self.PositionalEmbedding(x) # (seq_len,d_model)\n",
" return self.dropout(token + pe)\n"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "nlnS6PqJBADg",
"cell_id": "a8a61941bdb4420c8fd66f84e65ad44d",
"deepnote_cell_type": "markdown"
},
"source": [
"# LayerNorm"
]
},
{
"cell_type": "code",
"metadata": {
"id": "jaN9Oauvx5MG",
"source_hash": "2e7186ce",
"execution_start": 1711541714259,
"execution_millis": 101,
"deepnote_to_be_reexecuted": false,
"cell_id": "0be27f0e5ce341f19e49ee91163a7cb6",
"deepnote_cell_type": "code"
},
"source": [
"class LayerNorm(nn.Module):\n",
" def __init__(self, d_model, eps=1e-10):\n",
" super(LayerNorm, self).__init__()\n",
" self.gemma = nn.Parameter(torch.ones(d_model))\n",
" self.beta = nn.Parameter(torch.zeros(d_model))\n",
" self.eps = eps\n",
"\n",
" def forward(self, x):\n",
" mean = x.mean(-1, keepdim=True) # 保持维度方便计算\n",
" var = x.var(-1, keepdim=True, unbiased=False) # 使用有偏估计\n",
" out = (x - mean) / torch.sqrt(var + self.eps)\n",
" out = out * self.gemma + self.beta\n",
" return out\n"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ogKEyTefBBmJ",
"cell_id": "7b70e0a3e0e845babb19631b73d1fb8e",
"deepnote_cell_type": "markdown"
},
"source": [
"# FFN"
]
},
{
"cell_type": "code",
"metadata": {
"id": "nWZjqG_b_tTe",
"source_hash": "b1721ee9",
"execution_start": 1711541714260,
"execution_millis": 100,
"deepnote_to_be_reexecuted": false,
"cell_id": "12596bebdd504d369bb5648ddc32bfe0",
"deepnote_cell_type": "code"
},
"source": [
"class FFN(nn.Module):\n",
" def __init__(self, d_model, hidden_size, drop_out=0.1):\n",
" super(FFN, self).__init__()\n",
" self.fc1 = nn.Linear(d_model, hidden_size)\n",
" self.fc2 = nn.Linear(hidden_size, d_model)\n",
" self.dropout = nn.Dropout(drop_out)\n",
"\n",
" def forward(self, x):\n",
" x = F.relu(self.fc1(x))\n",
" out = self.fc2(self.dropout(x))\n",
" return out\n"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "CE-wnzEwBDix",
"cell_id": "0bfee901eaf148ea8e695164b7ce56e6",
"deepnote_cell_type": "markdown"
},
"source": [
"# MultiHead-Attention"
]
},
{
"cell_type": "code",
"metadata": {
"id": "hQmszCuyQAnH",
"source_hash": "c3c12d3b",
"execution_start": 1711541714263,
"execution_millis": 120,
"deepnote_to_be_reexecuted": false,
"cell_id": "6a46cc717d9049f5b41ea37d263c3859",
"deepnote_cell_type": "code"
},
"source": [
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, d_model, n_head):\n",
" super(MultiHeadAttention, self).__init__()\n",
" self.d_model = d_model\n",
" self.n_head = n_head\n",
" self.W_q = nn.Linear(d_model, d_model)\n",
" self.W_k = nn.Linear(d_model, d_model)\n",
" self.W_v = nn.Linear(d_model, d_model)\n",
" self.O = nn.Linear(d_model, d_model)\n",
" self.d_k = d_model // n_head\n",
"\n",
" def forward(self, q, k, v, mask):\n",
" batch_size, seq_len, _ = q.shape # (batch_size,seq_len,d_model)\n",
" Q = self.W_q(q).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)\n",
" K = self.W_k(k).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)\n",
" V = self.W_v(v).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)\n",
" # (batch_size,n_head,seq_len,d_k)\n",
" scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(\n",
" torch.tensor(self.d_k, dtype=torch.float)\n",
" ) # (batch_size,n_head,seq_len,seq_len)\n",
" if mask is not None:\n",
" # mask=torch.tril(torch.ones(seq_len,seq_len,dtype=bool))\n",
" scores = scores.masked_fill(mask == 0, float(\"-inf\"))\n",
" attn = F.softmax(scores, dim=-1)\n",
" context = (\n",
" torch.matmul(attn, V).transpose(1, 2).reshape(batch_size, -1, self.d_model)\n",
" )\n",
" # (batch_size,n_head,seq_len,d_k)\n",
" output = self.O(context)\n",
" # (batch_size,seq_len,d_model)\n",
" return output\n"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZTJ225quBWt2",
"cell_id": "f560551557bd457ba4b8abedae2e6c76",
"deepnote_cell_type": "markdown"
},
"source": [
"# Encoder Layer"
]
},
{
"cell_type": "code",
"metadata": {
"id": "wIzvSMrlA4QF",
"source_hash": "2c0b555e",
"execution_start": 1711541714285,
"execution_millis": 99,
"deepnote_to_be_reexecuted": false,
"cell_id": "acafabedd15e4e8aba25213b2ba46a67",
"deepnote_cell_type": "code"
},
"source": [
"class EncoderLayer(nn.Module):\n",
" def __init__(self, d_model, hidden_size, n_head, drop_prob):\n",
" super(EncoderLayer, self).__init__()\n",
" self.attention = MultiHeadAttention(d_model, n_head)\n",
" self.norm1 = LayerNorm(d_model)\n",
" self.norm2 = LayerNorm(d_model)\n",
" self.ffn = FFN(d_model, hidden_size,drop_prob)\n",
" self.dropout1 = nn.Dropout(drop_prob)\n",
" self.dropout2 = nn.Dropout(drop_prob)\n",
"\n",
" def forward(self, x, mask=None):\n",
" _x = x # for residual using\n",
" x = self.attention(x, x, x, mask)\n",
" x = self.dropout1(x)\n",
" x = self.norm1(x + _x)\n",
" _x = x\n",
" x = self.ffn(x)\n",
" x = self.dropout2(x)\n",
" x = self.norm2(x + _x)\n",
" return x\n"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"formattedRanges": [],
"cell_id": "eb565d3ed15b4ac98414d79848e4ed65",
"deepnote_cell_type": "text-cell-h1",
"id": "j-uENHYV3p0t"
},
"source": [
"# Encoder"
]
},
{
"cell_type": "code",
"metadata": {
"source_hash": "c97d5e14",
"execution_start": 1711541714286,
"execution_millis": 98,
"deepnote_to_be_reexecuted": false,
"cell_id": "201536382100453e85c5a0b7ba449a1b",
"deepnote_cell_type": "code",
"id": "_pcFa3LD3p0t"
},
"source": [
"class Encoder(nn.Module):\n",
" def __init__(self,max_len,vocab_size,device,d_model,hidden_size,n_head,drop_prob,layer_num):\n",
" super(Encoder,self).__init__()\n",
" self.embedding=TotalEmbedding(d_model,max_len,vocab_size,device,drop_prob)\n",
" self.layers=nn.ModuleList(\n",
" [EncoderLayer(d_model,hidden_size,n_head,drop_prob) for _ in range(layer_num)]\n",
" )\n",
"\n",
" def forward(self,x,padding_mask):\n",
" x=self.embedding(x)\n",
" for layer in self.layers:\n",
" x=layer(x,padding_mask)\n",
" return x"
],
"execution_count": 9,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"cell_id": "c557cf875f044069af27c1169f8fb4ec",
"deepnote_cell_type": "markdown",
"id": "ZRDNyLRv3p0u"
},
"source": [
"# Decoder Layer"
]
},
{
"cell_type": "code",
"metadata": {
"source_hash": "27693686",
"execution_start": 1711541714286,
"execution_millis": 99,
"deepnote_to_be_reexecuted": false,
"cell_id": "91e8bf93454646258229b2a1062f855d",
"deepnote_cell_type": "code",
"id": "uA3I-14k3p0u"
},
"source": [
"class DecoderLayer(nn.Module):\n",
" def __init__(self, d_model, hidden_size, n_head, drop_prob):\n",
" super(DecoderLayer,self).__init__()\n",
" self.attention=MultiHeadAttention(d_model,n_head)\n",
" self.cross_attention=MultiHeadAttention(d_model,n_head)\n",
" self.norm1=LayerNorm(d_model)\n",
" self.norm2=LayerNorm(d_model)\n",
" self.norm3=LayerNorm(d_model)\n",
" self.ffn=FFN(d_model,hidden_size,drop_prob)\n",
" self.drop_out1=nn.Dropout(drop_prob)\n",
" self.drop_out2=nn.Dropout(drop_prob)\n",
" self.drop_out3=nn.Dropout(drop_prob)\n",
" def forward(self,enc,dec,src_mask,trg_mask):\n",
" _x=dec\n",
" x=self.attention(dec,dec,dec,trg_mask)\n",
" x=self.drop_out1(x)\n",
" x=self.norm1(x+_x)\n",
"\n",
" _x=x\n",
" x=self.cross_attention(dec,enc,enc,src_mask) # future mask contain future mask and padding mask\n",
" x=self.drop_out2(x)\n",
" x=self.norm2(x+_x)\n",
"\n",
" _x=x\n",
" x=self.ffn(x)\n",
" x=self.drop_out3(x)\n",
" x=self.norm3(x+_x)\n",
" return x\n"
],
"execution_count": 10,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"formattedRanges": [],
"cell_id": "972dc805b64c49509134d885e78a68c5",
"deepnote_cell_type": "text-cell-h1",
"id": "mDDoIh8X3p0u"
},
"source": [
"# Decoder"
]
},
{
"cell_type": "code",
"metadata": {
"source_hash": "28e11345",
"execution_start": 1711541714334,
"execution_millis": 268,
"deepnote_to_be_reexecuted": false,
"cell_id": "0d4c98a6ef62467aa88ecd7bb0bb6c6b",
"deepnote_cell_type": "code",
"id": "zoA-gOqJ3p0u"
},
"source": [
"class Decoder(nn.Module):\n",
" def __init__(self,max_len,vocab_size,device,d_model,hidden_size,n_head,drop_prob,layer_num):\n",
" super(Decoder,self).__init__()\n",
" self.embedding=TotalEmbedding(d_model,max_len,vocab_size,device,drop_prob)\n",
" self.layers=nn.ModuleList(\n",
" [DecoderLayer(d_model,hidden_size,n_head,drop_prob) for _ in range(layer_num)]\n",
" )\n",
" self.linear=nn.Linear(d_model,vocab_size)\n",
" def forward(self,enc,dec,src_mask,trg_mask):\n",
" dec=self.embedding(dec)\n",
" for layer in self.layers:\n",
" dec=layer(enc,dec,src_mask,trg_mask)\n",
" dec=self.linear(dec)\n",
" out=F.softmax(dec,dim=-1)\n",
" return out"
],
"execution_count": 11,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"formattedRanges": [],
"cell_id": "441bc8472a164ed1978a6edce3861289",
"deepnote_cell_type": "text-cell-h1",
"id": "AdWIWuvy3p0u"
},
"source": [
"# Transformer"
]
},
{
"cell_type": "code",
"metadata": {
"source_hash": "92b4cdc8",
"execution_start": 1711541889626,
"execution_millis": 33,
"deepnote_to_be_reexecuted": false,
"cell_id": "d1571b0984fb4fe8962f9b70742c85d0",
"deepnote_cell_type": "code",
"id": "qwrMo5CS3p0u"
},
"source": [
"class Transformer(nn.Module):\n",
" def __init__(self,src_pad_idx,trg_pad_idx,max_len,enc_vocab_size,dec_vocab_size,device,d_model,hidden_size,n_head,drop_prob,layer_num):\n",
" super(Transformer,self).__init__()\n",
" self.encoder=Encoder(max_len,enc_vocab_size,device,d_model,hidden_size,n_head,drop_prob,layer_num)\n",
" self.decoder=Decoder(max_len,dec_vocab_size,device,d_model,hidden_size,n_head,drop_prob,layer_num)\n",
" self.src_pad_idx = src_pad_idx\n",
" self.trg_pad_idx = trg_pad_idx\n",
" self.device=device\n",
"\n",
" def forward(self, src, trg):\n",
" src_mask = self.make_src_mask(src)\n",
" trg_mask = self.make_trg_mask(trg)\n",
" enc = self.encoder(src, src_mask)\n",
" output = self.decoder(enc, trg, src_mask, trg_mask)\n",
" return output # 确保返回output\n",
"\n",
" def make_src_mask(self, src):\n",
" src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)\n",
" return src_mask\n",
"\n",
" def make_trg_mask(self, trg):\n",
"\n",
" trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(3) # 确保维度正确 (batch_size,n_head,seq_len,seq_len)\n",
" trg_len = trg.size(1)\n",
" trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=self.device)).type(torch.bool) # (seq_len,seq_len)\n",
" trg_mask = trg_pad_mask & trg_sub_mask\n",
" return trg_mask"
],
"execution_count": 12,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"formattedRanges": [],
"cell_id": "3457c47436f8475fa5f3581c188bf041",
"deepnote_cell_type": "text-cell-h1",
"id": "k6qS8saW3p0u"
},
"source": [
"# Experiment"
]
},
{
"cell_type": "code",
"metadata": {
"source_hash": "bea0b6b2",
"execution_start": 1711541895993,
"execution_millis": 674,
"deepnote_to_be_reexecuted": false,
"cell_id": "f5664ef823ea43a89f3b0cf4fe187b33",
"deepnote_cell_type": "code",
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gKPW5rjg3p0v",
"outputId": "6e99a32b-02e9-4a88-b103-53f62ab0f9b0"
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"# from your_transformer_module import Transformer # 确保引入了你的 Transformer 类\n",
"\n",
"# 1. 准备数据\n",
"batch_size = 32\n",
"src_seq_len = 10\n",
"trg_seq_len = 15\n",
"src_pad_idx = 0\n",
"trg_pad_idx = 0\n",
"max_len = 20\n",
"enc_vocab_size = 1000\n",
"dec_vocab_size = 1000\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"d_model = 512\n",
"hidden_size = 2048\n",
"n_head = 8\n",
"drop_prob = 0.1\n",
"layer_num = 6\n",
"\n",
"src = torch.randint(1, enc_vocab_size, (batch_size, src_seq_len)).to(device)\n",
"trg = torch.randint(1, dec_vocab_size, (batch_size, trg_seq_len)).to(device)\n",
"\n",
"# 2. 初始化模型\n",
"model = Transformer(src_pad_idx, trg_pad_idx, max_len, enc_vocab_size, dec_vocab_size, device, d_model, hidden_size, n_head, drop_prob, layer_num).to(device)\n",
"\n",
"# 3. 前向传播测试\n",
"try:\n",
" output = model(src, trg[:, :-1]) # trg的输入去掉最后一个,以便于教师强制\n",
" assert output.shape == (batch_size, trg_seq_len - 1, dec_vocab_size), \"输出形状不匹配\"\n",
" print(\"前向传播测试通过!\")\n",
"except AssertionError as e:\n",
" print(f\"前向传播测试失败: {e}\")\n",
"\n",
"# 4. 掩码生成测试\n",
"src_mask = model.make_src_mask(src)\n",
"trg_mask = model.make_trg_mask(trg[:, :-1])\n",
"\n",
"try:\n",
" assert src_mask.shape == (batch_size, 1, 1, src_seq_len), \"源掩码形状不正确\"\n",
" assert trg_mask.shape == (batch_size, 1, trg_seq_len - 1, trg_seq_len - 1), \"目标掩码形状不正确\"\n",
" print(\"掩码生成测试通过!\")\n",
"except AssertionError as e:\n",
" print(f\"掩码生成测试失败: {e}\")\n"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"前向传播测试通过!\n",
"掩码生成测试通过!\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"source_hash": "b623e53d",
"deepnote_to_be_reexecuted": true,
"cell_id": "1d1a09a9e1a04e31b2a39298d6faad08",
"deepnote_cell_type": "code",
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "khDgNI_93p0v",
"outputId": "07c08084-06da-4ecc-8e23-4bd61a267539"
},
"source": [
"import torch\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"import numpy as np\n",
"\n",
"# 假设的模型参数\n",
"src_pad_idx = 0\n",
"trg_pad_idx = 0\n",
"max_len = 10\n",
"enc_vocab_size = 20 # 假设词汇量大小\n",
"dec_vocab_size = 20\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"d_model = 512\n",
"hidden_size = 2048\n",
"n_head = 8\n",
"drop_prob = 0.1\n",
"layer_num = 6\n",
"\n",
"# 初始化模型\n",
"model = Transformer(src_pad_idx, trg_pad_idx, max_len, enc_vocab_size, dec_vocab_size, device, d_model, hidden_size, n_head, drop_prob, layer_num)\n",
"model.to(device)\n",
"\n",
"# 生成简单的合成数据\n",
"def generate_data(batch_size, seq_len, vocab_size):\n",
" data = torch.randint(1, vocab_size, (batch_size, seq_len)) # 避免生成pad token\n",
" return data\n",
"\n",
"batch_size = 32\n",
"seq_len = 10\n",
"\n",
"# 生成数据\n",
"src_data = generate_data(batch_size, seq_len, enc_vocab_size)\n",
"trg_data = src_data.clone() # 输入和输出相同\n",
"\n",
"# 数据加载器\n",
"dataset = TensorDataset(src_data, trg_data)\n",
"dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
"\n",
"# 简单训练循环\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
"criterion = torch.nn.CrossEntropyLoss(ignore_index=trg_pad_idx)\n",
"\n",
"for epoch in range(100): # 训练一个epoch进行测试\n",
" model.train()\n",
" total_loss = 0\n",
" for src, trg in dataloader:\n",
" src, trg = src.to(device), trg.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(src, trg[:, :-1]) # 忽略序列的最后一个元素\n",
" output = output.reshape(-1, output.shape[-1])\n",
" trg = trg[:, 1:].reshape(-1) # 忽略序列的第一个元素\n",
" loss = criterion(output, trg)\n",
" loss.backward()\n",
" optimizer.step()\n",
" total_loss += loss.item()\n",
" print(f\"Epoch {epoch}, Loss: {total_loss / len(dataloader)}\")\n",
"\n",
"# 预期输出: 观察到的损失应该随着时间的推移而减少,表明模型正在学习复制任务\n"
],
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 0, Loss: 2.995344877243042\n",
"Epoch 1, Loss: 2.984524726867676\n",
"Epoch 2, Loss: 2.9938313961029053\n",
"Epoch 3, Loss: 2.9807004928588867\n",
"Epoch 4, Loss: 2.9887120723724365\n",
"Epoch 5, Loss: 2.998847246170044\n",
"Epoch 6, Loss: 2.9817023277282715\n",
"Epoch 7, Loss: 2.9803860187530518\n",
"Epoch 8, Loss: 2.9835989475250244\n",
"Epoch 9, Loss: 2.9836642742156982\n",
"Epoch 10, Loss: 2.9835991859436035\n",
"Epoch 11, Loss: 2.9806711673736572\n",
"Epoch 12, Loss: 2.9795029163360596\n",
"Epoch 13, Loss: 2.982041358947754\n",
"Epoch 14, Loss: 2.980454444885254\n",
"Epoch 15, Loss: 2.981685161590576\n",
"Epoch 16, Loss: 2.980349540710449\n",
"Epoch 17, Loss: 2.980302572250366\n",
"Epoch 18, Loss: 2.981624126434326\n",
"Epoch 19, Loss: 2.979287624359131\n",
"Epoch 20, Loss: 2.9816741943359375\n",
"Epoch 21, Loss: 2.9805760383605957\n",
"Epoch 22, Loss: 2.9801182746887207\n",
"Epoch 23, Loss: 2.9790592193603516\n",
"Epoch 24, Loss: 2.9804089069366455\n",
"Epoch 25, Loss: 2.980847120285034\n",
"Epoch 26, Loss: 2.97930645942688\n",
"Epoch 27, Loss: 2.9797005653381348\n",
"Epoch 28, Loss: 2.9782049655914307\n",
"Epoch 29, Loss: 2.979945182800293\n",
"Epoch 30, Loss: 2.977987051010132\n",
"Epoch 31, Loss: 2.9785232543945312\n",
"Epoch 32, Loss: 2.9788577556610107\n",
"Epoch 33, Loss: 2.9791769981384277\n",
"Epoch 34, Loss: 2.978320360183716\n",
"Epoch 35, Loss: 2.9786667823791504\n",
"Epoch 36, Loss: 2.978853464126587\n",
"Epoch 37, Loss: 2.978756904602051\n",
"Epoch 38, Loss: 2.9776389598846436\n",
"Epoch 39, Loss: 2.9791746139526367\n",
"Epoch 40, Loss: 2.978610038757324\n",
"Epoch 41, Loss: 2.980536460876465\n",
"Epoch 42, Loss: 2.979384422302246\n",
"Epoch 43, Loss: 2.9789557456970215\n",
"Epoch 44, Loss: 2.978343963623047\n",
"Epoch 45, Loss: 2.9789674282073975\n",
"Epoch 46, Loss: 2.9782252311706543\n",
"Epoch 47, Loss: 2.9781506061553955\n",
"Epoch 48, Loss: 2.9780097007751465\n",
"Epoch 49, Loss: 2.9773736000061035\n",
"Epoch 50, Loss: 2.9779841899871826\n",
"Epoch 51, Loss: 2.9753427505493164\n",
"Epoch 52, Loss: 2.979438543319702\n",
"Epoch 53, Loss: 2.9791324138641357\n",
"Epoch 54, Loss: 2.9781386852264404\n",
"Epoch 55, Loss: 2.9788925647735596\n",
"Epoch 56, Loss: 2.976363182067871\n",
"Epoch 57, Loss: 2.9784836769104004\n",
"Epoch 58, Loss: 2.9774186611175537\n",
"Epoch 59, Loss: 2.976188898086548\n",
"Epoch 60, Loss: 2.9777863025665283\n",
"Epoch 61, Loss: 2.9790074825286865\n",
"Epoch 62, Loss: 2.9771623611450195\n",
"Epoch 63, Loss: 2.978694438934326\n",
"Epoch 64, Loss: 2.977710723876953\n",
"Epoch 65, Loss: 2.977355718612671\n",
"Epoch 66, Loss: 2.9788594245910645\n",
"Epoch 67, Loss: 2.9797263145446777\n",
"Epoch 68, Loss: 2.9768307209014893\n",
"Epoch 69, Loss: 2.976647138595581\n",
"Epoch 70, Loss: 2.979998826980591\n",
"Epoch 71, Loss: 2.9781081676483154\n",
"Epoch 72, Loss: 2.9787964820861816\n",
"Epoch 73, Loss: 2.9788260459899902\n",
"Epoch 74, Loss: 2.977302074432373\n",
"Epoch 75, Loss: 2.9781789779663086\n",
"Epoch 76, Loss: 2.9773879051208496\n",
"Epoch 77, Loss: 2.97847318649292\n",
"Epoch 78, Loss: 2.9774057865142822\n",
"Epoch 79, Loss: 2.9770147800445557\n",
"Epoch 80, Loss: 2.9789209365844727\n",
"Epoch 81, Loss: 2.977285861968994\n",
"Epoch 82, Loss: 2.9775390625\n",
"Epoch 83, Loss: 2.979658365249634\n",
"Epoch 84, Loss: 2.9776415824890137\n",
"Epoch 85, Loss: 2.979929208755493\n",
"Epoch 86, Loss: 2.979032278060913\n",
"Epoch 87, Loss: 2.9790756702423096\n",
"Epoch 88, Loss: 2.9792330265045166\n",
"Epoch 89, Loss: 2.9780335426330566\n",
"Epoch 90, Loss: 2.9781672954559326\n",
"Epoch 91, Loss: 2.976395845413208\n",
"Epoch 92, Loss: 2.9773950576782227\n",
"Epoch 93, Loss: 2.9803178310394287\n",
"Epoch 94, Loss: 2.9772472381591797\n",
"Epoch 95, Loss: 2.9769339561462402\n",
"Epoch 96, Loss: 2.980050802230835\n",
"Epoch 97, Loss: 2.9778597354888916\n",
"Epoch 98, Loss: 2.977893829345703\n",
"Epoch 99, Loss: 2.977822780609131\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=7ee49933-c18c-41ac-a77e-19b1d981850e' target=\"_blank\">\n",
"<img alt='Created in deepnote.com' style='display:inline;max-height:16px;margin:0px;margin-right:7.5px;' src='data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiPz4KPHN2ZyB3aWR0aD0iODBweCIgaGVpZ2h0PSI4MHB4IiB2aWV3Qm94PSIwIDAgODAgODAiIHZlcnNpb249IjEuMSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIiB4bWxuczp4bGluaz0iaHR0cDovL3d3dy53My5vcmcvMTk5OS94bGluayI+CiAgICA8IS0tIEdlbmVyYXRvcjogU2tldGNoIDU0LjEgKDc2NDkwKSAtIGh0dHBzOi8vc2tldGNoYXBwLmNvbSAtLT4KICAgIDx0aXRsZT5Hcm91cCAzPC90aXRsZT4KICAgIDxkZXNjPkNyZWF0ZWQgd2l0aCBTa2V0Y2guPC9kZXNjPgogICAgPGcgaWQ9IkxhbmRpbmciIHN0cm9rZT0ibm9uZSIgc3Ryb2tlLXdpZHRoPSIxIiBmaWxsPSJub25lIiBmaWxsLXJ1bGU9ImV2ZW5vZGQiPgogICAgICAgIDxnIGlkPSJBcnRib2FyZCIgdHJhbnNmb3JtPSJ0cmFuc2xhdGUoLTEyMzUuMDAwMDAwLCAtNzkuMDAwMDAwKSI+CiAgICAgICAgICAgIDxnIGlkPSJHcm91cC0zIiB0cmFuc2Zvcm09InRyYW5zbGF0ZSgxMjM1LjAwMDAwMCwgNzkuMDAwMDAwKSI+CiAgICAgICAgICAgICAgICA8cG9seWdvbiBpZD0iUGF0aC0yMCIgZmlsbD0iIzAyNjVCNCIgcG9pbnRzPSIyLjM3NjIzNzYyIDgwIDM4LjA0NzY2NjcgODAgNTcuODIxNzgyMiA3My44MDU3NTkyIDU3LjgyMTc4MjIgMzIuNzU5MjczOSAzOS4xNDAyMjc4IDMxLjY4MzE2ODMiPjwvcG9seWdvbj4KICAgICAgICAgICAgICAgIDxwYXRoIGQ9Ik0zNS4wMDc3MTgsODAgQzQyLjkwNjIwMDcsNzYuNDU0OTM1OCA0Ny41NjQ5MTY3LDcxLjU0MjI2NzEgNDguOTgzODY2LDY1LjI2MTk5MzkgQzUxLjExMjI4OTksNTUuODQxNTg0MiA0MS42NzcxNzk1LDQ5LjIxMjIyODQgMjUuNjIzOTg0Niw0OS4yMTIyMjg0IEMyNS40ODQ5Mjg5LDQ5LjEyNjg0NDggMjkuODI2MTI5Niw0My4yODM4MjQ4IDM4LjY0NzU4NjksMzEuNjgzMTY4MyBMNzIuODcxMjg3MSwzMi41NTQ0MjUgTDY1LjI4MDk3Myw2Ny42NzYzNDIxIEw1MS4xMTIyODk5LDc3LjM3NjE0NCBMMzUuMDA3NzE4LDgwIFoiIGlkPSJQYXRoLTIyIiBmaWxsPSIjMDAyODY4Ij48L3BhdGg+CiAgICAgICAgICAgICAgICA8cGF0aCBkPSJNMCwzNy43MzA0NDA1IEwyNy4xMTQ1MzcsMC4yNTcxMTE0MzYgQzYyLjM3MTUxMjMsLTEuOTkwNzE3MDEgODAsMTAuNTAwMzkyNyA4MCwzNy43MzA0NDA1IEM4MCw2NC45NjA0ODgyIDY0Ljc3NjUwMzgsNzkuMDUwMzQxNCAzNC4zMjk1MTEzLDgwIEM0Ny4wNTUzNDg5LDc3LjU2NzA4MDggNTMuNDE4MjY3Nyw3MC4zMTM2MTAzIDUzLjQxODI2NzcsNTguMjM5NTg4NSBDNTMuNDE4MjY3Nyw0MC4xMjg1NTU3IDM2LjMwMzk1NDQsMzcuNzMwNDQwNSAyNS4yMjc0MTcsMzcuNzMwNDQwNSBDMTcuODQzMDU4NiwzNy43MzA0NDA1IDkuNDMzOTE5NjYsMzcuNzMwNDQwNSAwLDM3LjczMDQ0MDUgWiIgaWQ9IlBhdGgtMTkiIGZpbGw9IiMzNzkzRUYiPjwvcGF0aD4KICAgICAgICAgICAgPC9nPgogICAgICAgIDwvZz4KICAgIDwvZz4KPC9zdmc+' > </img>\n",
"Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>"
],
"metadata": {
"created_in_deepnote_cell": true,
"deepnote_cell_type": "markdown",
"id": "lZ_LDbtZ3p0v"
}
}
],
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"deepnote_persisted_session": {
"createdAt": "2024-03-27T08:00:50.222Z"
},
"deepnote_full_width": true,
"deepnote_notebook_id": "15281b94ce5b48a38d2ff6c59c47cf1c",
"deepnote_execution_queue": [],
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"language_info": {
"name": "python"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment