@tianyu-lu
Created October 31, 2021 17:31
1-Introduction.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "1-Introduction.ipynb",
"provenance": [],
"collapsed_sections": [
"DiSkj7uxUtPf",
"Rh-XDyt3XOwZ",
"RDlVKKA2ZOSL",
"cTbaUm19m1hV"
],
"toc_visible": true,
"authorship_tag": "ABX9TyON1YSaGONgl4NnJQJNCMCC",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/tianyu-lu/faf802a58b996061034fcf08f28d1502/1-introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "58FOqtF0mYyN"
},
"source": [
"## 1. Getting Started"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JLPtJRnviXGe"
},
"source": [
"Learning outcome: the ability to organize biological data in Python. Students should be able to manipulate and reshape n-dimensional arrays in both PyTorch and NumPy, and be familiar with basic NumPy operations."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5r2fiauMm4f_"
},
"source": [
"### Basic NumPy"
]
},
{
"cell_type": "code",
"metadata": {
"id": "sxWxO8QyWmA7"
},
"source": [
"import numpy as np"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lVghX5CbinOo",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f273b33f-ff0e-4071-a3dd-4703dbc4be87"
},
"source": [
"np.array([1,2,3])"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([1, 2, 3])"
]
},
"metadata": {
"tags": []
},
"execution_count": 2
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "qEGHUHyKiqAp",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "45480ebb-4a40-4870-8ca5-7a0da6174c1a"
},
"source": [
"np.ones((2,2))"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[1., 1.],\n",
" [1., 1.]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "F0Kd95crit0z",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "b48466ff-f8c8-4679-b972-106386579354"
},
"source": [
"a = np.zeros((3,5))\n",
"a.shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(3, 5)"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "wIFVgCVliwO4",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "18eedb34-7b1d-4e9c-ccda-38e00101f134"
},
"source": [
"np.random.rand(3)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0.94533 , 0.4745626 , 0.96398767])"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "rkeIRTgzi7Ir",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4922851e-d335-4727-83d1-9b22cb73788d"
},
"source": [
"np.random.randn(3,3)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[-0.71945885, -0.49846205, 0.25290342],\n",
" [ 1.11512355, 0.59535581, 2.10845993],\n",
" [ 3.00792761, 0.67807833, -0.11648451]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "iWz18r1jjBqd",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "72bd7136-3f97-4fe1-b93e-54946eb4918c"
},
"source": [
"a = np.random.randint(10, size=(3,3))\n",
"print(a)\n",
"print(\"Sum 0: \", np.sum(a, axis=0))\n",
"print(\"Sum 1: \", np.sum(a, axis=1))\n",
"print(\"Mean 0: \", np.mean(a, axis=0))\n",
"print(\"Mean 1: \", np.mean(a, axis=1))\n",
"print(\"Std 0: \", np.std(a, axis=0))\n",
"print(\"Std 1: \", np.std(a, axis=1))\n",
"print(\"Frobenius norm: \", np.linalg.norm(a))\n",
"print(\"Norm 1\", np.linalg.norm(a, axis=1))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"[[9 7 8]\n",
" [5 9 4]\n",
" [8 9 6]]\n",
"Sum 0: [22 25 18]\n",
"Sum 1: [24 18 23]\n",
"Mean 0: [7.33333333 8.33333333 6. ]\n",
"Mean 1: [8. 6. 7.66666667]\n",
"Std 0: [1.69967317 0.94280904 1.63299316]\n",
"Std 1: [0.81649658 2.1602469 1.24721913]\n",
"Frobenius norm: 22.293496809607955\n",
"Norm 1 [13.92838828 11.04536102 13.45362405]\n"
],
"name": "stdout"
}
]
},
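{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a small sketch of what the `axis` argument means. It uses a hand-picked $2 \\times 3$ array (an illustrative example, not course data). `axis=0` collapses the rows and gives one value per column. `axis=1` collapses the columns and gives one value per row. Passing `keepdims=True` keeps the reduced axis with length 1, which is handy for broadcasting."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Hand-picked 2x3 array so the results are easy to check by eye\n",
"m = np.array([[1, 2, 3],\n",
"              [4, 5, 6]])\n",
"print(m.sum(axis=0))  # [5 7 9]: one value per column\n",
"print(m.sum(axis=1))  # [ 6 15]: one value per row\n",
"print(m.sum(axis=1, keepdims=True).shape)  # (2, 1): reduced axis kept with length 1"
],
"execution_count": null,
"outputs": []
},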
{
"cell_type": "code",
"metadata": {
"id": "VOQJlYQ_jReV",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "baaa91e3-9d09-4d5b-ec38-72d9fb76a508"
},
"source": [
"b = a.reshape(-1,1)\n",
"b.shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(9, 1)"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "F99em2I-ja6Y",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "6c7c96dd-67cc-4ac9-b00b-e4b53664f807"
},
"source": [
"c = b.squeeze()\n",
"c.shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(9,)"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
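{
"cell_type": "markdown",
"metadata": {},
"source": [
"This sketch shows how `reshape` fills in a `-1`. NumPy infers that one dimension from the total number of elements and keeps the elements in row-major (C) order. `squeeze` removes axes of length 1, and `np.expand_dims` adds one back. The array `x` is an illustrative example."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"x = np.arange(6)                   # [0 1 2 3 4 5]\n",
"print(x.reshape(2, -1))            # -1 is inferred as 3; rows are filled in order\n",
"col = x.reshape(-1, 1)             # shape (6, 1)\n",
"print(col.squeeze().shape)         # (6,): the length-1 axis is dropped\n",
"print(np.expand_dims(x, 0).shape)  # (1, 6): a new leading axis is added"
],
"execution_count": null,
"outputs": []
},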
{
"cell_type": "code",
"metadata": {
"id": "5WMlyYNljh2H",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "da28812c-41fb-478c-d85d-d0b0a69709d0"
},
"source": [
"sym = (a + a.T) / 2\n",
"sym"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[9. , 6. , 8. ],\n",
" [6. , 9. , 6.5],\n",
" [8. , 6.5, 6. ]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 11
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "byAt-QVDjrOO",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d1e9f293-b6a9-46e1-d385-6671f9d7e093"
},
"source": [
"vec = np.array([1,2,3])\n",
"np.dot(sym, vec)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([45. , 43.5, 39. ])"
]
},
"metadata": {
"tags": []
},
"execution_count": 12
}
]
},
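{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of the two cells above, this sketch copies the `sym` matrix printed above into a new variable. The names `sym_demo` and `vec_demo` are hypothetical and leave the notebook's own `sym` and `vec` untouched. The copied matrix equals its own transpose, and the `@` operator gives the same matrix-vector product as `np.dot`. Each entry of that product is the dot product of one row with the vector."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"sym_demo = np.array([[9.0, 6.0, 8.0],\n",
"                     [6.0, 9.0, 6.5],\n",
"                     [8.0, 6.5, 6.0]])  # values copied from the output above\n",
"vec_demo = np.array([1, 2, 3])\n",
"\n",
"print(np.allclose(sym_demo, sym_demo.T))  # True: the matrix is symmetric\n",
"print(np.allclose(sym_demo @ vec_demo, np.dot(sym_demo, vec_demo)))  # True\n",
"# Entry 0 is row 0 dotted with the vector: 9*1 + 6*2 + 8*3\n",
"print(9 * 1 + 6 * 2 + 8 * 3)  # 45"
],
"execution_count": null,
"outputs": []
},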
{
"cell_type": "markdown",
"metadata": {
"id": "vRW-Z3qKjkbB"
},
"source": [
"### Basic PyTorch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bVgSQcqAjn8r"
},
"source": [
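"A NumPy array and a PyTorch tensor can hold the same data, but they are different types, and functions from one library do not necessarily accept the other's type. The cells below show the resulting errors and how to convert between the two with `torch.from_numpy()` and `.numpy()`."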
]
},
{
"cell_type": "code",
"metadata": {
"id": "PNpOFpkOj5QQ"
},
"source": [
"import torch"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "RoH9erzVj6vI",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 186
},
"outputId": "fc72046b-2c04-486c-8f75-b23cb4a5feae"
},
"source": [
"a = np.random.randn(3,3)\n",
"torch.sum(a)"
],
"execution_count": null,
"outputs": [
{
"output_type": "error",
"ename": "TypeError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-14-f58adc1afa27>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0ma\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m: sum(): argument 'input' (position 1) must be Tensor, not numpy.ndarray"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TEuPpiFmkQ-v",
"outputId": "85038b75-8dce-43ec-a440-5cda5bcef4d4"
},
"source": [
"a = torch.from_numpy(a)\n",
"torch.sum(a)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(0.2163, dtype=torch.float64)"
]
},
"metadata": {
"tags": []
},
"execution_count": 15
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "yaYT4gEmkXSG",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 458
},
"outputId": "c67f74bf-a1a2-4d07-df6a-4a54b1df4a57"
},
"source": [
"np.sum(a)"
],
"execution_count": null,
"outputs": [
{
"output_type": "error",
"ename": "TypeError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-16-acdd4d3dd195>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36msum\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36msum\u001b[0;34m(a, axis, dtype, out, keepdims, initial, where)\u001b[0m\n\u001b[1;32m 2240\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2241\u001b[0m return _wrapreduction(a, np.add, 'sum', axis, dtype, out, keepdims=keepdims,\n\u001b[0;32m-> 2242\u001b[0;31m initial=initial, where=where)\n\u001b[0m\u001b[1;32m 2243\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2244\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36m_wrapreduction\u001b[0;34m(obj, ufunc, method, axis, dtype, out, **kwargs)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpasskwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 85\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpasskwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mufunc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduce\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpasskwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: sum() received an invalid combination of arguments - got (out=NoneType, axis=NoneType, ), but expected one of:\n * (*, torch.dtype dtype)\n didn't match because some of the keywords were incorrect: out, axis\n * (tuple of names dim, bool keepdim, *, torch.dtype dtype)\n * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IFXfC0eFkctH",
"outputId": "1a9a9bcc-6e98-4484-d0fe-9bd68b8dbb65"
},
"source": [
"a = a.numpy()\n",
"np.sum(a)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.21633286730829915"
]
},
"metadata": {
"tags": []
},
"execution_count": 17
}
]
},
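{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two details of this conversion are worth knowing. First, `torch.from_numpy` shares memory with the original array, so changing one changes the other, while `torch.tensor` makes an independent copy. Second, the dtype carries over: NumPy's default float64 becomes `torch.float64`, whereas tensors created directly by PyTorch default to float32. Here is a minimal sketch with an illustrative array:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"arr = np.zeros(3)               # float64, NumPy's default\n",
"shared = torch.from_numpy(arr)  # shares memory with arr\n",
"copied = torch.tensor(arr)      # independent copy\n",
"arr[0] = 5.0                    # modify the NumPy array in place\n",
"print(shared)                   # tensor([5., 0., 0.], dtype=torch.float64)\n",
"print(copied)                   # tensor([0., 0., 0.], dtype=torch.float64)\n",
"print(torch.zeros(3).dtype)     # torch.float32: PyTorch's default"
],
"execution_count": null,
"outputs": []
},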
{
"cell_type": "markdown",
"metadata": {
"id": "O4xessH9mlIn"
},
"source": [
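"The PyTorch equivalents of the NumPy operations above differ slightly. For example, `torch.zeros` accepts the dimensions as separate arguments, and shapes are reported as `torch.Size` objects rather than plain tuples."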
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8xrlX3YPnCBo",
"outputId": "d6f79291-5299-4f84-e019-ddcc9ec29646"
},
"source": [
"torch.zeros(2,1)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0.],\n",
" [0.]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 18
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SXseMnTLnPN-",
"outputId": "13184d93-6f3b-4285-adeb-5ce97df86547"
},
"source": [
"a = torch.randn(3,3)\n",
"a"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.2733, -0.2770, 2.9781],\n",
" [ 0.2537, 0.3480, -0.9627],\n",
" [-1.9232, -0.3513, -0.6138]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 19
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Kyk0sbWenP_u",
"outputId": "3a5de64e-457d-4508-ffb4-2a15eb425c9d"
},
"source": [
"b = a.reshape(-1,1)\n",
"b.shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([9, 1])"
]
},
"metadata": {
"tags": []
},
"execution_count": 20
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "x_PHQANPngkT",
"outputId": "599feee2-4827-4fc8-87f3-e16c7b0aae1b"
},
"source": [
"c = a.flatten()\n",
"c.shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([9])"
]
},
"metadata": {
"tags": []
},
"execution_count": 25
}
]
},
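{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a sketch of a few more PyTorch shape operations on an illustrative tensor `t`. `reshape(-1)` gives the same 1-D result as `flatten()` here, and `squeeze(dim)` removes a specific length-1 axis. `unsqueeze(0)` adds a leading axis, which is a common way to turn a single example into a batch of size 1."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"\n",
"t = torch.arange(9).reshape(3, 3)\n",
"print(torch.equal(t.flatten(), t.reshape(-1)))  # True\n",
"print(t.reshape(-1, 1).squeeze(1).shape)        # torch.Size([9])\n",
"print(t.unsqueeze(0).shape)                     # torch.Size([1, 3, 3])"
],
"execution_count": null,
"outputs": []
},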
{
"cell_type": "markdown",
"metadata": {
"id": "PDftZOLcoBTf"
},
"source": [
"Let's move on to matrix operations and neural networks!"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RaYaqqEHnoRH",
"outputId": "df9b5f7a-6129-435e-b2c2-d2b07cd22d1c"
},
"source": [
"I = torch.eye(3)\n",
"I @ a"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.2733, -0.2770, 2.9781],\n",
" [ 0.2537, 0.3480, -0.9627],\n",
" [-1.9232, -0.3513, -0.6138]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 26
}
]
},
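{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `@` is matrix multiplication, while `*` multiplies element by element. This sketch uses a small hand-picked matrix. Multiplying it by the identity with `@` returns the matrix unchanged, but `*` keeps only its diagonal entries."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"\n",
"m2 = torch.tensor([[1.0, 2.0],\n",
"                   [3.0, 4.0]])\n",
"I2 = torch.eye(2)\n",
"print(torch.equal(I2 @ m2, m2))  # True: matrix product with the identity\n",
"print(I2 * m2)                   # elementwise: off-diagonal entries become 0"
],
"execution_count": null,
"outputs": []
},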
{
"cell_type": "markdown",
"metadata": {
"id": "7j9LbSLdJI4D"
},
"source": [
"In practice we usually want to multiply a whole batch of matrices at once, which `torch.bmm` does."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "208vq0AyoGPG",
"outputId": "89ca2726-5524-4508-f3bf-a7402f9055ea"
},
"source": [
"a = torch.randint(10, size=(2,4,3))\n",
"b = torch.randint(10, size=(2,3,1))\n",
"c = torch.bmm(a, b)\n",
"print(c.shape)\n",
"print(c)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"torch.Size([2, 4, 1])\n",
"tensor([[[ 5],\n",
" [17],\n",
" [17],\n",
" [28]],\n",
"\n",
" [[82],\n",
" [83],\n",
" [74],\n",
" [47]]])\n"
],
"name": "stdout"
}
]
},
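{
"cell_type": "markdown",
"metadata": {},
"source": [
"`torch.bmm` expects 3-D tensors of shapes $(b, n, m)$ and $(b, m, p)$ and returns shape $(b, n, p)$. This sketch uses hypothetical tensors `batch_a` and `batch_b`. It checks that `bmm` gives the same result as multiplying each pair of matrices in a Python loop, and that `@` (`torch.matmul`) batches in the same way."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"\n",
"batch_a = torch.randint(10, size=(2, 4, 3))\n",
"batch_b = torch.randint(10, size=(2, 3, 1))\n",
"batched = torch.bmm(batch_a, batch_b)  # shape (2, 4, 1)\n",
"\n",
"# Same result, one batch element at a time\n",
"looped = torch.stack([batch_a[i] @ batch_b[i] for i in range(batch_a.shape[0])])\n",
"print(torch.equal(batched, looped))             # True\n",
"print(torch.equal(batched, batch_a @ batch_b))  # True"
],
"execution_count": null,
"outputs": []
},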
{
"cell_type": "markdown",
"metadata": {
"id": "imSfDr4ZKPEf"
},
"source": [
"So far we have only multiplied matrices. To train a model, we also need to record the operations of the forward pass so that gradients can be computed by backpropagation. PyTorch's autograd does this automatically for tensors that require gradients, such as the parameters of `nn` modules."
]
},
{
"cell_type": "code",
"metadata": {
"id": "1A1rDKU2KlJg"
},
"source": [
"import torch.nn as nn"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MzpD2FZVKpr1",
"outputId": "f0097325-95df-4918-ffaa-6c63c1e607a3"
},
"source": [
"model = nn.Linear(10, 5)\n",
"a = torch.randn(2, 10) # 2 is the batch dimension\n",
"out = model(a)\n",
"out.shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([2, 5])"
]
},
"metadata": {
"tags": []
},
"execution_count": 29
}
]
},
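{
"cell_type": "markdown",
"metadata": {},
"source": [
"`nn.Linear(10, 5)` computes $y = x W^\\top + b$. The weight $W$ is stored with shape `(out_features, in_features)` = $(5, 10)$, and the bias $b$ has 5 entries. This sketch checks that by hand. The names `layer` and `x_demo` are illustrative, chosen so the notebook's `model` and `out` are not overwritten."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"layer = nn.Linear(10, 5)\n",
"x_demo = torch.randn(2, 10)\n",
"print(layer.weight.shape)  # torch.Size([5, 10])\n",
"print(layer.bias.shape)    # torch.Size([5])\n",
"manual = x_demo @ layer.weight.T + layer.bias\n",
"print(torch.allclose(layer(x_demo), manual))  # True"
],
"execution_count": null,
"outputs": []
},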
{
"cell_type": "markdown",
"metadata": {
"id": "EhFEmJtdMGyl"
},
"source": [
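"A model's objective function must be a scalar, because calling `backward()` with no arguments only works on a single-element tensor. As a toy example, we use the Euclidean norm of `out - 1`. For a single output this is its distance to the point $(1,1,1,1,1)$; for a batch, the norm is taken over all entries at once."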
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cZ5_AEliL7Jm",
"outputId": "d01571d5-6e0d-4a7b-d0ba-13af9afe8d7e"
},
"source": [
"loss_fn = lambda x: torch.norm(x-1)\n",
"loss = loss_fn(out)\n",
"loss"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(3.0192, grad_fn=<CopyBackwards>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 30
}
]
},
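{
"cell_type": "markdown",
"metadata": {},
"source": [
"`torch.norm` with no other arguments flattens its input. For a batch it therefore returns $\\sqrt{\\sum_{i,j} (y_{ij} - 1)^2}$ over all entries. This sketch uses a hypothetical hand-picked output so the value can be checked by eye."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"\n",
"y = torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0],\n",
"                  [2.0, 1.0, 1.0, 1.0, 3.0]])  # hypothetical model output\n",
"norm_val = torch.norm(y - 1)\n",
"manual = torch.sqrt(((y - 1) ** 2).sum())\n",
"print(norm_val)                          # tensor(2.2361), i.e. sqrt(1 + 4)\n",
"print(torch.allclose(norm_val, manual))  # True"
],
"execution_count": null,
"outputs": []
},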
{
"cell_type": "markdown",
"metadata": {
"id": "Zfq316pvMp_y"
},
"source": [
"Calling `loss.backward()` tells PyTorch to compute the gradient of the loss with respect to each model parameter. Here the parameters are the ones created by `model = nn.Linear(10, 5)`: a $5 \\times 10$ weight matrix, stored with shape `(out_features, in_features)`, and a $5$-element bias vector."
]
},
{
"cell_type": "code",
"metadata": {
"id": "yD7UIOCcM9ml"
},
"source": [
"model = nn.Linear(10, 5)\n",
"a = torch.randn(2, 10)\n",
"out = model(a)\n",
"loss = loss_fn(out)\n",
"loss.backward()"
],
"execution_count": null,
"outputs": []
},
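{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `backward()` computes, here is a sketch on a case we can differentiate by hand. The gradient of the Euclidean norm is $\\nabla_v \\lVert v \\rVert = v / \\lVert v \\rVert$, so for $v = (3, 4)$ we expect $(0.6, 0.8)$. The tensor `v` is illustrative."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"\n",
"v = torch.tensor([3.0, 4.0], requires_grad=True)\n",
"n = torch.norm(v)  # 5.0\n",
"n.backward()       # populates v.grad\n",
"print(v.grad)      # tensor([0.6000, 0.8000])"
],
"execution_count": null,
"outputs": []
},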
{
"cell_type": "markdown",
"metadata": {
"id": "xuyY6k6DNyKO"
},
"source": [
"We can look at the gradients by going through the model parameters."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uDjnd7_aN3B3",
"outputId": "222961d6-e529-48ca-b7cb-13a7333cb15c"
},
"source": [
"for p in model.parameters():\n",
"    print(p.grad)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor([[-0.0590, 0.2874, -0.3905, -0.2997, -0.5172, 0.1235, 0.3597, 0.5080,\n",
" -0.4311, 0.3273],\n",
" [-0.0205, -0.0092, -0.3154, -0.3037, -0.5282, 0.1663, -0.0740, 0.3344,\n",
" -0.7991, 0.4118],\n",
" [-0.0061, -0.1590, -0.3520, -0.3769, -0.6577, 0.2267, -0.3071, 0.3265,\n",
" -1.1697, 0.5505],\n",
" [-0.0307, 0.0526, -0.3626, -0.3330, -0.5784, 0.1738, 0.0105, 0.4044,\n",
" -0.8006, 0.4348],\n",
" [-0.0120, -0.0174, -0.2052, -0.2005, -0.3489, 0.1113, -0.0652, 0.2140,\n",
" -0.5411, 0.2748]])\n",
"tensor([-0.4511, -0.6191, -0.8480, -0.6450, -0.4148])\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VOmrKATROkMJ"
},
"source": [
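"We can now minimize the loss with gradient descent. After each backward pass, every parameter $\\theta$ is updated as $\\theta \\leftarrow \\theta - \\eta \\nabla_\\theta L$, where $\\eta$ is the learning rate. In practice we would use an optimizer such as Adam, which we will do later; this loop shows the basic mechanism that such optimizers build on. Note that each of the `num_epochs` iterations draws a fresh random batch and takes a single gradient step."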
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 279
},
"id": "hZt7AG00OcqG",
"outputId": "b7a2c476-eb2c-4c5e-eb1e-998c35f6042e"
},
"source": [
"model = nn.Linear(10, 5)\n",
"batch_size = 32\n",
"num_epochs = 100\n",
"learning_rate = 0.01\n",
"losses = []\n",
"for epoch_i in range(num_epochs):\n",
"    data = torch.randn(batch_size, 10)  # a fresh random batch each step\n",
"    out = model(data)\n",
"    loss = loss_fn(out)\n",
"    losses.append(loss.item())\n",
"    model.zero_grad()  # clear the gradients before computing new ones\n",
"    loss.backward()\n",
"    with torch.no_grad():  # update the parameters without tracking the update in autograd\n",
"        for p in model.parameters():\n",
"            p -= learning_rate * p.grad\n",
"\n",
"import matplotlib.pyplot as plt\n",
"plt.plot(losses)\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.show()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEGCAYAAABiq/5QAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd5xU9b3/8ddnZnth+9JhAQEBBcRVQA02NMQYNdEUS9RoYnpMuXr1l/xS7u/eVG9uYoo3FtTEksQSS4wGY8UCsiBVinSWtrvAsgW2f35/zKiIrC7Lzpzdmffz8dgHM2dm57yPB9+c/e4532PujoiIJI9Q0AFERCS+VPwiIklGxS8ikmRU/CIiSUbFLyKSZFKCDtAVxcXFXlZWFnQMEZE+ZeHChTXuXnLw8j5R/GVlZVRUVAQdQ0SkTzGzTYdarqEeEZEko+IXEUkyKn4RkSQTs+I3s9lmVmVmyw9a/nUzW2VmK8zs57Fav4iIHFosj/jvAmYduMDMTgfOBya5+wTgphiuX0REDiFmxe/uLwK7D1r8ZeCn7t4cfU9VrNYvIiKHFu8x/jHAh8xsvpm9YGYndPZGM7vGzCrMrKK6ujqOEUVEElu8iz8FKASmAdcBfzUzO9Qb3f1Wdy939/KSkvdcf3BEahqaeXhRZY9+pohIXxHv4q8EHvaI14AOoDjOGfjVv9bw7b8u4c2d9fFetYhI4OJd/I8ApwOY2RggDaiJZ4Cm1nYeW7wNgLlvxnXVIiK9QixP57wfeBUYa2aVZnY1MBsYGT3F88/AFR7nW4DNeWMndU1tpKeEePFN/e5ARJJPzObqcfeLO3npslitsyseqNjC4PxMzhxXyl8rttDc1k56SjjISCIicZVUV+5uq93PS2truHDKYE4dU0JTawcLN+4JOpaISFwldPE3NrfR3Nb+9vOHF1XiDhcdP5RpI4tIDRsvapxfRJJMQhf/r595k9N/8Tx/fm0zre0dPLiwkqkjChlWlEV2egpThhUwV+P8IpJkErr4TxtTQkm/DG54eBmn/OxZNu7axyfLh779+owxJazYVkdNQ3OAKUVE4iuhi/+ko4p55Csncfvl5RRmp1Ock845xw54+/UPjY5cQvDyWg33iEjy6BN34DoSZsbM8f05c1wpLe0d7zqD55hBeRRkpfLCmmrOnzw4wJQiIvGT0Ef8BzKz95y2GQoZp4wuYe6bNcT5cgIRkcAkTfF35tQxJVTXN7NAp3WKSJJI+uL/6LEDKcpO45bn1wYdRUQkLpK++DPTwnzu5DKeW13NG9vqgo4jIhJzSV/8AJ+dXkZOegq3vLAu6CgiIjGn4gfyMlO5dNownli6jY01jUHHERGJKRV/1NWnjCAlHOIPL64POoqISEyp+KNKczP45PFDeGhhJVX1TUHHERGJGRX/AS6ZOoyW9g5eWbsr6CgiIjGj4j/A0QP6kZuewoKNu4OOIiISM7G8A9dsM6uK3m3r4Ne+Y2ZuZnG/3+77CYeM44YXsHCTLuYSkcQVyyP+u4BZBy80s6HA2cDmGK6728qHF7B6Zz1797cGHUVEJCZiVvzu/iJwqDGT/wGuB3rl5DjlZQW4w6LNOuoXkcQU1zF+Mzsf2OruS7rw3mvMrMLMKqqr43ezlMlD8wmHjAqN84tIgopb8ZtZFvB/gO935f3ufqu7l7t7eUlJSWzDHSArLYVjBvWjQpO2iUiCiucR/yhgBLDEzDYCQ4BFZjbgfb8rAMcPL2Txllpa2jqCjiIi0uPiVvzuvszdS929zN3LgEpgirvviFeGrjqhrIDmtg5WbNsbdBQRkR4Xy9M57wdeBcaaWaWZXR2rdfW048sKADTcIyIJKWa3XnT3iz/g9bJYrftIleZmMLwoi4pNu/kCI4OOIyLSo3TlbieOH15AxcY977olY31TK7fPXc9Hb57LC2vid6aRiEhPSvibrXdX+fBCHl60le8/uoKstD
B1Ta08vmQ7Dc1tpIVD/PbZNzl1TPzONhIR6Skq/k6cOraEktx0HlxYieOEzDhrfH+uPmUE89bv4sf/WMXqHfWMHZAbdFQRkcOi4u/E4PxMFnx35iFfG1KQxU3/XMN98zfxo/OPiXMyEZEjozH+bijMTuOcYwfw8KKt7GtpCzqOiMhhUfF306XThlPf3MbjS7YFHUVE5LCo+LupfHgBY/vncs+8XjnJqIhIp1T83WRmXDptGMu27mVpZW3QcUREukzFfwQuOG4wWWlh7nhpQ9BRRES6TMV/BPplpHLZtOE8vmQbG2oag44jItIlKv4j9PkPjSA1HOJ3z60NOoqISJeo+I9QaW4Gl0wdxt9e38qW3fuCjiMi8oFU/D3gizNGETbj98+vCzqKiMgHUvH3gAF5GXzqhCE8uHAL22r3Bx1HROR9qfh7yJdOHYU73DRnddBRRETel4q/hwwpyOJLp47i4UVbeWr59qDjiIh0KpZ34JptZlVmtvyAZb8ws1VmttTM/mZm+bFafxCunTmaiUPyuOHhZeysawo6jojIIcXyiP8uYNZBy54GjnH3icAa4MYYrj/uUsMh/ufTk2lu7eDfHlhCR4d/8DeJiMRZzIrf3V8Edh+0bI67vzWd5TxgSKzWH5RRJTl879xxzH2zhnvmbwo6jojIewQ5xn8V8GRnL5rZNWZWYWYV1dV96zaHl5w4jBPKCpj90oZ33bpRRKQ3CKT4zey7QBtwb2fvcfdb3b3c3ctLSvrWLQ7NjAunDGHjrn2s2FYXdBwRkXeJe/Gb2ZXAucClnsCHw7OOGUBKyDRfv4j0OnEtfjObBVwPnOfuCT2/QX5WGjPGlPD3pdvfNdzzQMUWrntgCY8v2cbefa0BJhSRZBWze+6a2f3AaUCxmVUCPyByFk868LSZAcxz9y/FKkPQPjZpIM+uqmLR5lqOH17AhppGvvu35bS788DCSsIh4/SxJfz4E8dSmpsRdFwRSRIxK353v/gQi++I1fp6o5nj+pOWEuLxJduYMiyf7z+6nPSUEHO+PYNttU08s3Ins1/ewDm/fombPzOZk44qDjqyiCQBXbkbQ7kZqZwxtpQnlm3nsSXbmPtmDf/24bEMzMvk+OEFXD/raB796inkZaZw2R3z+d8XNMmbiMSeij/GPjZpENX1zVz/4FKOHZzHZdOGv+v1sQNyeexrp3D2+AH89MlVbNqlG7qISGyp+GPsjKNLyUoL09LewX99/BjCIXvPe7LTU/jheRMIh4z7XtPN20UktmI2xi8RmWlhvjVzDI4zcUjnUxMNyMvgrHH9eaCikm+fNYb0lHAcU4pIMtERfxx8YcZIrpkx6gPfd+m0YexubOHJZTvikEpEkpWKvxc5eVQxZUVZ3Ks5fkQkhlT8vUgoZFwydRgLNu5h1Q5N9SAisaHi72UuOn4oaSkh7p2nX/KKSGyo+HuZwuw0PnrsQB5aVEnFxt0f/A0iIodJxd8LffusMZTmpnPJbfN55PWtQccRkQSj4u+FhhZm8bevnMyU4fl88y+L+eWc1ZrXX0R6jIq/lyrITuOPV03lU+VDuPnZtdzx0oagI4lIgtAFXL1YWkqIn35iIvVNbfz4HysZ0z+XGWP61k1pRKT30RF/LxcKGTd9chJj+ufytfsWsbFGc/mIyJFR8fcB2ekp3HZ5OaGQ8fk/VrC/pT3oSCLSh6n4+4ihhVncdNEk1lY18OyqqqDjiEgfFrPiN7PZZlZlZssPWFZoZk+b2ZvRPwtitf5EdPrRpRRlp/HPFZrLR0S6L5ZH/HcBsw5adgPwjLuPBp6JPpcuCoeMmeP689yqKlraOoKOIyJ9VMyK391fBA6+9PR84O7o47uBC2K1/kR19oT+1De38er6XUFHEZE+Kt5j/P3dfXv08Q6gf5zX3+edfFQxWWlh5mi4R0S6KbBf7nrkUtROL0c1s2vMrMLMKqqrq+OYrHfLSA1z2tgSnn5jJx0duppXRA5fvIt/p5kNBIj+2e
npKe5+q7uXu3t5SYkuWjrQhycMoKq+mde31AYdRUT6oHgX/2PAFdHHVwCPxnn9CeG0saWkhIw5b2i4R0QOXyxP57wfeBUYa2aVZnY18FPgLDN7E5gZfS6HKS8zlemjipizYqcmbxORwxazuXrc/eJOXjozVutMJh+eMIDvPbKcNTsbGDsgN+g4ItKH6MrdPmrWMQNIC4e4T/fnFZHDpOLvo4pz0jl34kAeXFhJfVNr0HFEpA9R8fdhV5xURmNLOw8trAw6ioj0ISr+PmzS0HwmD83nj69u0jn9ItJlKv4+7sqTylhf08jctTVBRxGRPkLF38edc+xAinPSufuVjUFHEZE+QsXfx6WlhLhk6jCeW13Fmzvrg44jIn2Aij8BXDZtGHmZqVx55wIq9+wLOo6I9HIq/gRQmpvBPVdPpb6plYtvm8f2vfuDjiQivZiKP0EcMziPP109ldrGVi6+dR4765qCjiQivZSKP4FMGprPXVedSHV9M5+9Yz57GluCjiQivZCKP8EcP7yA2684gY279nHFna/pql4ReY8uFb+ZZZtZKPp4jJmdZ2apsY0m3TV9VBG3XDqFN7bVcfXdFdz9yka+fv/rnPaL5/jLgs1BxxORgHX1iP9FIMPMBgNzgM8SuZm69FJnjuvPLz89mQUbd/ODx1bw2oZd1De18ad5mtRNJNl1dVpmc/d90Tn1f+/uPzezxbEMJkfuvEmDmDg4j3DIGFKQya0vrucnT65ia+1+BudnBh1PRALS1SN+M7PpwKXAE9Fl4dhEkp5UVpzN0MIszIyzJwwA4OmDbtT+6OKtPL+607tgikiC6WrxfxO4Efibu68ws5HAc7GLJbEwojibo0pzmPPGzreXVdc3c92DS7lpzuoAk4lIPHWp+N39BXc/z91/Fv0lb427f6O7KzWzb5nZCjNbbmb3m1lGdz9LDs/Z4/szf8NuavdFTvWc/fIGWto6WLm9nv0t7QGnE5F46OpZPfeZWT8zywaWA2+Y2XXdWWH0F8TfAMrd/RgiQ0af6c5nyeE7e8IA2juc51ZXUdfUyj2vbqI0N532Dmf5tr1BxxOROOjqUM94d68DLgCeBEYQObOnu1KATDNLAbKAbUfwWXIYJg7Oo3+/dOas2Mm98zZT39zGzy6aCMDizbUBpxOReOhq8adGz9u/AHjM3VuBbt35w923AjcBm4HtwF53n3Pw+8zsGjOrMLOK6urq7qxKDiEUMs4a35/nV1dzx0sb+NDoYk4fW8rg/EwWb1HxiySDrhb/H4CNQDbwopkNB+q6s0IzKwDOJ/JTwyAg28wuO/h97n6ru5e7e3lJSUl3ViWdOHv8APa3tlPT0MyXTxsFwORh+Sp+kSTR1V/u3uzug939HI/YBJzezXXOBDa4e3X0J4eHgZO6+VnSDdNGFpGbkcKkoflMH1kEwHFD89lau58qTe4mkvC6dAGXmeUBPwBmRBe9APwH0J3fBm4GpplZFrAfOBOo6MbnSDelpYT409VTKcpOw8wAOG5YPgCvb6nlw9Hz/UUkMXV1qGc2UA98KvpVB9zZnRW6+3zgQWARsCya4dbufJZ03+Sh+QwtzHr7+YRBeaSETMM9Ikmgq1M2jHL3Cw94/qMjmbLB3X9A5CcI6SUyUsOMG9hPZ/aIJIGuHvHvN7NT3npiZicTGaaRBDJ5aD5LK2tp73jnhK0DH4tIYuhq8X8J+J2ZbTSzjcBvgS/GLJUEYvLQfBpb2llb1cCexhY+9YdX+fjvX6ZD5S+SULo01OPuS4BJZtYv+rzOzL4JLI1lOImvydFf8D62ZCtPLtvB+ppGAJ5bXcWZ4/oHGU1EetBh3YHL3euiV/ACfDsGeSRAI4qyyctM5XfPrWNXYwv3f2Eag/IyuG3u+qCjiUgPOpJbL1qPpZBeIRQyTh1TwpCCTB768nSmjyricyePYN763Syr1Dw+IoniSIpfA78J6KZPTuLF607nqNJcAD594lBy0lN01C+SQN63+M2s3s
zqDvFVT2S6BUkwaSkhQqF3fpjrl5HKxScO5Yll29laqxO5RBLB+xa/u+e6e79DfOW6e1evAZA+7sqTRwBw50sbAk4iIj3hSIZ6JEkMzs/kYxMH8qd5m1i4aXfQcUTkCKn4pUv+77njGZSfydV3V7C2qiHoOCJyBFT80iVFOenc/bkTSQmFuGL2a+zULJ4ifZaKX7psWFEWd33uBGr3tXDZ7fNZvlWneIr0RSp+OSzHDM7jtsvL2bOvlfN++xI/fGwFdU2tQccSkcOg4pfDdtJRxTzznVO5bNpw7n51Ix/51VwamtuCjiUiXaTil27Jy0zlP84/htsvL2dr7X6eWbkz6Egi0kUqfjkip48tZWBeBo8v2RZ0FBHpokCK38zyzexBM1tlZivNbHoQOeTIhULGuRMH8sKaavbu01i/SF8Q1BH/r4Gn3P1oYBKwMqAc0gM+NmkQre3OP1fsCDqKiHRB3Is/euP2GcAdAO7e4u66318fduzgPIYXZfH4Ug33iPQFQRzxjwCqgTvN7HUzu93Msg9+k5ldY2YVZlZRXV0d/5TSZWbGxyYO4uW1NdQ0NAcdR0Q+QBDFnwJMAW5x9+OARuCGg9/k7re6e7m7l5eUlMQ7oxym8yYPosPhyWXbg44iIh8giOKvBCrdfX70+YNE/iGQPmxM/1zG9s/l8SUqfpHeLu7F7+47gC1mNja66EzgjXjnkJ73sUkDeW3jbl5ZVxN0FBF5H0Gd1fN14F4zWwpMBn4cUA7pQZdOHc7o0hyuumuByl+kFwuk+N19cXT8fqK7X+Due4LIIT2rIDuN+6+ZxrDCrEj5r1X5i/RGunJXelRxTjr3f2EawwuzufKuBXzzz6/zzxU7aGptDzqaiESp+KXHFeWkc98XpvKJ4wbz/JpqvvinhZzwn/9i3vpdQUcTEcDcPegMH6i8vNwrKiqCjiHd0Nrewfz1u/n+o8tpbGnjyWtnUJidFnQskaRgZgvdvfzg5Tril5hKDYc4ZXQxv7nkOPY0tnLdA0voCwcbIolMxS9xMWFQHjeeczTPrKrizpc3Bh1HJKmp+CVurjypjJnjSvnpk6tYvEXTM4kERcUvcWNm/PyiSZT2S+equxawvroh6EgiSUnFL3FVmJ3GH686EYDLZ79GVV1TwIlEko+KX+JuZEkOd155ArsbW7h89mu6WbtInKn4JRCThubzv5cdz+qd9dwxd0PQcUSSiopfAjNjTAnlwwv4l27ULhJXKn4J1Mxx/VmxrY5ttfuDjiKSNFT8Eqgzx/UH4Bkd9YvEjYpfAjWqJJsRxdk8vbIq6CgiSUPFL4EyM2aOK2Xeul00NLcFHUckKaj4JXAzx/Wnpb2DuWuqg44ikhQCK34zC5vZ62b296AySO9w/PAC8rNSeVrj/CJxEeQR/7XAygDXL71ESjjEGWNLeW5VFW3tHUHHEUl4gRS/mQ0BPgrcHsT6pfeZOb4/e/a18rOnVvGVexcy9cf/4hf/XBV0LJGEFNQR/6+A64FOD+/M7BozqzCziupqjf0muhljSshIDXHb3A0s3lxLSW46v39+Ha9t2B10NJGEE/c7cJnZucA57v4VMzsN+Dd3P/f9vkd34EoO66obSAuHGFqYRWNzGx/59VzM4MlrP0RWWkrQ8UT6nN50B66TgfPMbCPwZ+AMM7sngBzSy4wqyWFoYRYA2ekp/PyiiWzatY+fP7U64GQiiSXuxe/uN7r7EHcvAz4DPOvul8U7h/R+00YWceVJZdz1ykbdqF2kB+k8funVrp81lqGFmfzXEyt1r16RHhJo8bv78x80vi/JLSsthS/OGMWyrXup2LQn6DgiCUFH/NLrXThlCPlZqZq3X6SHqPil18tMC3PJicOY88YOtuzeF3QckT5PxS99wuXTywiZcefLG4OOItLnqfilTxiQl8G5Ewfy14ot1OsevSJHRMUvfcZVp4ygobmNvyzYEnQUkT5NxS99xsQh+UwfWcRNc1bzyrqaoOOI9Fkqfu
lTfnPJcQwrzOKquxao/EW6ScUvfUpxTjr3fWHaO+W/VuUvcrhU/NLnvFX+wwuzufKuBTy1fEfQkUT6FBW/9EnFOen8+ZppTBjUj6/cu5B75m0KOpJIn6Hilz6rIDuN+z4/jdPHlvK9R5bz43+sZO9+neop8kFU/NKnZaaF+cNnj+fiE4dx64vrOeWnz/Lzp1ZR09AcdDSRXkvFL31eSjjETz5xLE984xRmjC3hlhfWMetXc6nThV4ih6Til4QxYVAev7tkCn/94nRqGpq5S9M7iBySil8SzgllhZw1vj+3zV2vMX+RQ1DxS0L65szR1De1MfslTeUscrC4F7+ZDTWz58zsDTNbYWbXxjuDJL4Jg/KYNWEAs1/awN59OuoXOVAQR/xtwHfcfTwwDfiqmY0PIIckuGtnjqa+uY3bX1ofdBSRXiWIm61vd/dF0cf1wEpgcLxzSOIbN7Af5xwbOerfvEs3cBF5S6Bj/GZWBhwHzD/Ea9eYWYWZVVRXV8c7miSIGz8yjlDI+Op9i2hqbQ86jkivEFjxm1kO8BDwTXevO/h1d7/V3cvdvbykpCT+ASUhDC3M4r8/OYllW/fyX0+sDDqOSK8QSPGbWSqR0r/X3R8OIoMkj7MnDOCaGSP507xNPLZkW9BxRAIXxFk9BtwBrHT3X8Z7/ZKcrvvwWMqHF3DDQ0tZvaM+6DgigQriiP9k4LPAGWa2OPp1TgA5JImkhkP89pIpZKen8IU/VrCnsSXoSCKBCeKsnpfc3dx9ortPjn79I945JPkMyMvgD589nh17m/jqfYtobe8IOpJIIHTlriSVKcMK+PEnjuWVdbv0y15JWilBBxCJt4uOH8LK7XXc8dIGThxRyDnHDgw6kkhc6YhfktINHzmaYwb34/8+spzdGu+XJKPil6SUGg7xi4smUdfUyg8eWxF0HJG4UvFL0ho3sB9fP2M0jy/Zphu2S1JR8UtS+/Jpoxg/sB/fe2Q5VXVNQccRiQsVvyS11HCI//7UJPa1tPGpP7zKlt2azE0Sn4pfkt64gf245/NT2d3Ywif/91XWVunKXklsKn4RIuf3/+WL02nrcJW/JDwVv0jUuIH9ePBL0zEzvvWXJbqyVxKWil/kAGXF2fznBcewbOtefv/cuqDjiMSEil/kIOccO5DzJw/iN8++yfKte4OOI9LjVPwih/Cj8yZQmJ3Gt/+6mOY23blLEouKX+QQ8rPS+NmFE1mzs4H/9/c3go4j0qNU/CKdOP3oUr44YyT3zNvMPfM2BR1H4qS9w3l13S52NTQHHSVmNDunyPu4ftbRrN5Zzw8fW8FRpTlMG1kEwN79rTS3tpOeGiY9JUR6SojIzeUOT3NbO1V1zQwpyHz7+1vaOnh+dRVLKmuZPrKYaSMLSQkf/jGau7Ng4x4eXbyV0twMLp02jOKc9He9p6m1nb8v3c79r22msbmNC6cM4RNTBlN00Pu6o6ahmRXb6ijMSqOsOIvcjNQj/syuamvvoK6pjcLstEO+3tTazvOrq3h57S5G98/hlKOKGV6UzRPLtvPrf61hXXUj/TJSuPGccXy6fCih0Pvv2737W9nX0kZxTjqp3dhXB2pt7zjiz/gg5u4xXcEhV2o2C/g1EAZud/efvt/7y8vLvaKiIi7ZRA5W19TKBb97mdp9rZw+tpTXN+9hfU3ju95TkpvOxMF5TByST3FuGvtb2tnX0k7d/lb27Guldl8LGWlhpo0s4qRRRYTMuP+1zTy4sJLdjS0UZKVy/PACCrPTmPPGTmr3tb792QVZqZw2tpT8rFTCZqSEQ2SnhcnNSCE7PQUzo8Odjg6nua2D/a3t7N3fyj+X72B9TSOZqWH2t7aTlhLi45MHM2FwP7bVNrG1dj8vrqlm7/5WRpZkk5eZyuuba0kNG5OH5pOWEiIU/ex9Le3sb2lnf2s7be1Oa3sHDuSmp5CbkUJuRipZaeFIHmBxZS3rq9/936goO42UsNHW7rR1ODnpKRRkp1KQlY
Y71De30dDUSl5mKkeV5nBUaQ6jS3MZOyCXgXkZNDS38eyqKp5avoM9+1qYOiLy33JEcTabdu9jQ3Uja3bWs6SylmVb99LU2sHIkmxmjC5hyvACGpraqKpvYn11I8+uqqKhuY30lBDNbZHTdrPTwjS2tDO2fy5XnVLGw4u2Mn/DbsqHF3DW+P6YgWGEQ0ZaSojUsLG+upFX1u1i+ba9uIMZFOekMygvg+FF2ZQVZ1OSk0Z7R2Sbm1rb2dXYwu7GFppa2xmYl8mQgkwy08Is2VLLgo172FDTSEluOqNKshlZksMV08sYOyC3W393zWyhu5e/Z3m8i9/MwsAa4CygElgAXOzunQ6kqvglaOurG7jwllcImXHcsAKOG5ZPXmYqTa3tNLW2s76mkaWVe1lX3cCB/0tlpoYpyEqlIDuNPY0tbNv7znxA4ZBx1rj+TB9VxPKte1m4aQ876pqYOa4/Hz9uMOVlBbyybhdPLtvOS2t30dzWHimQdqelC9cYlA8v4NMnDOWjEweyrbaJ2S9v4KGFlTS3dZAaNgbkZTBpSD6XTB3G9JFFmBlrdtbzlwVbWFa5l3Z3OtwxICsthay0MBmpYVLDkdIDaGhuo66pjfqmVva3tNPY0kZrmzN+UD9OKCtk0tA86va3sqFmH5t3N9Le4aSGQ6SEjPrmNvY0trB7Xythg5yMVLLTwuxubGFddQM1De9Ml52bnkJzWwct7R2U5qZT2i+dN7bV0XFQfaWnhDhmcB4Th+RRkpvO/PW7mb9hF02t7/z3KslN54yxpZw7aSDTRxaxtXY/L62tYfHmWk4dW8I5xwwkFDLcnYcWbeUn/1jJrk6m7k4Lh5g8LJ/pI4so7ZdOVV0zO+si/6hu3NXI1j3735MxJz2Fwuw00lNCbN/bRENzGwD5WamUDy9k3MBcduxtYn1NI+uqG/jDZcczNfqT5uHqTcU/Hfihu384+vxGAHf/SWffo+KX3qC1vYOUkL3vkE59Uyv7WtrJTAuTlRp+1xCNu7N59z5eWbeLxuY2zps0iNJ+Gd3K0tLWQUNzGw1NkdIIhSBkRnpKiMy0MBkp4UMOT9Q1tdLU0k5xTvoHDl8ErXZfC2t2NrB6Zz2rd9SRmRrmwxMGMGVYAYeW+yMAAAa6SURBVKGQsXdfK/M37GJHXRPDCrMYWZzDoPyM9wyLNbW2s766kfysVIpz0klLObxhlPYOp7mtHXfocKe9I/IPb0tbB0XZ6WSmhTv93ua2yE9fqaEQ4bBFhwXfeb+7s3d/K3X72xhSkHnIfeLu3RpGhN5V/BcBs9z989HnnwWmuvvXOvseFb+IyOHrrPh77Vk9ZnaNmVWYWUV1dXXQcUREEkYQxb8VGHrA8yHRZe/i7re6e7m7l5eUlMQtnIhIogui+BcAo81shJmlAZ8BHgsgh4hIUor7efzu3mZmXwP+SeR0ztnurpueiojESSAXcLn7P4B/BLFuEZFk12t/uSsiIrGh4hcRSTIqfhGRJBPIXD2Hy8yqge5Oj1gM1PRgnL4iGbc7GbcZknO7k3Gb4fC3e7i7v+d8+D5R/EfCzCoOdeVaokvG7U7GbYbk3O5k3Gboue3WUI+ISJJR8YuIJJlkKP5bgw4QkGTc7mTcZkjO7U7GbYYe2u6EH+MXEZF3S4YjfhEROYCKX0QkySR08ZvZLDNbbWZrzeyGoPPEgpkNNbPnzOwNM1thZtdGlxea2dNm9mb0z4Kgs/Y0Mwub2etm9vfo8xFmNj+6v/8Snf01oZhZvpk9aGarzGylmU1P9H1tZt+K/t1ebmb3m1lGIu5rM5ttZlVmtvyAZYfctxZxc3T7l5rZlMNZV8IWf/Tevr8DPgKMBy42s/HBpoqJNuA77j4emAZ8NbqdNwDPuPto4Jno80RzLbDygOc/A/7H3Y8C9gBXB5Iqtn4NPOXuRwOTiGx/wu5rMxsMfAMod/djiMzo+xkSc1/fBcw6aFln+/YjwOjo1zXALYezooQtfuBEYK27r3
f3FuDPwPkBZ+px7r7d3RdFH9cTKYLBRLb17ujb7gYuCCZhbJjZEOCjwO3R5wacATwYfUsibnMeMAO4A8DdW9y9lgTf10RmEc40sxQgC9hOAu5rd38R2H3Q4s727fnAHz1iHpBvZgO7uq5ELv7BwJYDnldGlyUsMysDjgPmA/3dfXv0pR1A/4BixcqvgOuBjujzIqDW3duizxNxf48AqoE7o0Nct5tZNgm8r919K3ATsJlI4e8FFpL4+/otne3bI+q3RC7+pGJmOcBDwDfdve7A1zxyzm7CnLdrZucCVe6+MOgscZYCTAFucffjgEYOGtZJwH1dQOTodgQwCMjmvcMhSaEn920iF3+X7u2bCMwslUjp3+vuD0cX73zrR7/on1VB5YuBk4HzzGwjkSG8M4iMfedHhwMgMfd3JVDp7vOjzx8k8g9BIu/rmcAGd69291bgYSL7P9H39Vs627dH1G+JXPxJcW/f6Nj2HcBKd//lAS89BlwRfXwF8Gi8s8WKu9/o7kPcvYzIfn3W3S8FngMuir4tobYZwN13AFvMbGx00ZnAGyTwviYyxDPNzLKif9ff2uaE3tcH6GzfPgZcHj27Zxqw94AhoQ/m7gn7BZwDrAHWAd8NOk+MtvEUIj/+LQUWR7/OITLm/QzwJvAvoDDorDHa/tOAv0cfjwReA9YCDwDpQeeLwfZOBiqi+/sRoCDR9zXwI2AVsBz4E5CeiPsauJ/I7zFaifx0d3Vn+xYwImctrgOWETnrqcvr0pQNIiJJJpGHekRE5BBU/CIiSUbFLyKSZFT8IiJJRsUvIpJkVPyS1Mys3cwWH/DVYxOcmVnZgTMtivQWKR/8FpGEtt/dJwcdQiSedMQvcghmttHMfm5my8zsNTM7Krq8zMyejc6B/oyZDYsu729mfzOzJdGvk6IfFTaz26Lzyc8xs8zo+78RvYfCUjP7c0CbKUlKxS/JLvOgoZ5PH/DaXnc/FvgtkdlAAX4D3O3uE4F7gZujy28GXnD3SUTmz1kRXT4a+J27TwBqgQujy28Ajot+zpditXEih6IrdyWpmVmDu+ccYvlG4Ax3Xx+dBG+HuxeZWQ0w0N1bo8u3u3uxmVUDQ9y9+YDPKAOe9shNNDCzfwdS3f0/zewpoIHItAuPuHtDjDdV5G064hfpnHfy+HA0H/C4nXd+r/ZRInOtTAEWHDDTpEjMqfhFOvfpA/58Nfr4FSIzggJcCsyNPn4G+DK8fS/gvM4+1MxCwFB3fw74dyAPeM9PHSKxoqMMSXaZZrb4gOdPuftbp3QWmNlSIkftF0eXfZ3IHbCuI3I3rM9Fl18L3GpmVxM5sv8ykZkWDyUM3BP9x8GAmz1yC0WRuNAYv8ghRMf4y929JugsIj1NQz0iIklGR/wiIklGR/wiIklGxS8ikmRU/CIiSUbFLyKSZFT8IiJJ5v8DTimokpvCeyMAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WO2lzlMYS5a0"
},
"source": [
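"The training appears to have been successful. Here is how to save the model's learned parameters (its `state_dict`) and load them back the next time we want to use it. Loading a `state_dict` requires first creating a model with the same architecture."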
]
},
{
"cell_type": "code",
"metadata": {
"id": "XXCceQ2XTjsd",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "0f78b99f-1dac-4420-db74-495892b02830"
},
"source": [
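"# Save only the learned parameters (the state_dict), not the whole model object\n",
"torch.save(model.state_dict(), \"mymodel.pt\")\n",
"\n",
"# To load, first build a model with the same architecture, then copy the saved parameters into it\n",
"model = nn.Linear(10, 5)\n",
"model.load_state_dict(torch.load(\"mymodel.pt\"))"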
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"metadata": {
"tags": []
},
"execution_count": 34
}
]
},
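{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the parameters were saved on a GPU and you later want to load them on a machine without one, `torch.load` accepts a `map_location` argument that remaps the stored tensors to another device. A minimal sketch, assuming the same `nn.Linear(10, 5)` architecture as above; the file name `mymodel_cpu.pt` is a hypothetical example:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"model = nn.Linear(10, 5)\n",
"torch.save(model.state_dict(), 'mymodel_cpu.pt')\n",
"\n",
"# map_location moves tensors that were saved on a GPU onto the CPU (a no-op if they are already there)\n",
"state = torch.load('mymodel_cpu.pt', map_location=torch.device('cpu'))\n",
"restored = nn.Linear(10, 5)\n",
"restored.load_state_dict(state)\n",
"\n",
"# The restored parameters match the saved ones exactly\n",
"print(torch.equal(model.weight, restored.weight), torch.equal(model.bias, restored.bias))"
],
"execution_count": null,
"outputs": []
},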
{
"cell_type": "markdown",
"metadata": {
"id": "o78DYfG0krg3"
},
"source": [
"PyTorch tensors (and models) can be moved to the GPU. Don't rush, though: GPU error messages can be less informative than CPU ones, so first run your model on small inputs on the CPU until it is bug-free, and only then switch to the GPU."
]
},
{
"cell_type": "code",
"metadata": {
"id": "T8u3zopPkwQX"
},
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "bT7hEQVqkz_P",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "b365e36d-5a5c-4071-e55c-40604c0d1199"
},
"source": [
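"a = np.random.randn(3,3)\n",
"a = torch.from_numpy(a)  # keeps NumPy's dtype, here float64\n",
"a = a.to(device)  # Tensor.to() does not modify a in place, so reassign the result\n",
"a"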
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[-0.8374, -0.3465, 0.0245],\n",
" [ 1.1270, -0.2721, -0.3044],\n",
" [ 0.3468, -1.5551, -1.1950]], dtype=torch.float64)"
]
},
"metadata": {
"tags": []
},
"execution_count": 36
}
]
},
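{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two details in the cell above are easy to miss. First, `torch.from_numpy` keeps NumPy's default `float64` dtype (note `dtype=torch.float64` in the output), while layers such as `nn.Linear` use `float32` parameters by default, so feeding that tensor directly to a layer raises a dtype error. Second, `Tensor.to()` does not move the original tensor in place; it returns the result. A minimal sketch of moving both a model and its input to `device`; the layer and input sizes are arbitrary:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"# For modules, .to() moves the parameters in place (and also returns the module)\n",
"model = nn.Linear(3, 2).to(device)\n",
"\n",
"x = torch.from_numpy(np.random.randn(4, 3))  # float64, on the CPU\n",
"x = x.float().to(device)  # cast to float32, then move; reassign because Tensor.to() is not in place\n",
"\n",
"y = model(x)\n",
"print(y.dtype, y.device, y.shape)"
],
"execution_count": null,
"outputs": []
},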
{
"cell_type": "markdown",
"metadata": {
"id": "ygMj19C4lPCn"
},
"source": [
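"If you have multiple GPUs, you can wrap the model with `nn.DataParallel()`; see [here](https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html) for details. For serious multi-GPU training, the PyTorch documentation recommends `nn.parallel.DistributedDataParallel` instead."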
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "79b6n2pZUT6c"
},
"source": [
"Exercise:\n",
"\n",
"1. Can you fix the error in the following code?\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "QbJfKCQsjmR7",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 369
},
"outputId": "eb524948-0974-44e5-d123-1e2e434f7085"
},
"source": [
"loss = nn.CrossEntropyLoss()\n",
"input = torch.randn(2, 3, 5, requires_grad=True)\n",
"target = torch.empty(2, 3, dtype=torch.long).random_(5)\n",
"output = loss(input, target)"
],
"execution_count": null,
"outputs": [
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-37-72fb2acc3b4e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequires_grad\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mtarget\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mempty\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, target)\u001b[0m\n\u001b[1;32m 960\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 961\u001b[0m return F.cross_entropy(input, target, weight=self.weight,\n\u001b[0;32m--> 962\u001b[0;31m ignore_index=self.ignore_index, reduction=self.reduction)\n\u001b[0m\u001b[1;32m 963\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 964\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mcross_entropy\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction)\u001b[0m\n\u001b[1;32m 2466\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msize_average\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mreduce\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2467\u001b[0m \u001b[0mreduction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegacy_get_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize_average\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2468\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mnll_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2469\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2470\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mnll_loss\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction)\u001b[0m\n\u001b[1;32m 2272\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2273\u001b[0m raise ValueError('Expected target size {}, got {}'.format(\n\u001b[0;32m-> 2274\u001b[0;31m out_size, target.size()))\n\u001b[0m\u001b[1;32m 2275\u001b[0m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontiguous\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2276\u001b[0m \u001b[0mtarget\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontiguous\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: Expected target size (2, 5), got torch.Size([2, 3])"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DiSkj7uxUtPf"
},
"source": [
"##### Solution"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RlTZxRSlUuq3",
"outputId": "771a4ce4-e668-44c0-cc1d-e79b586d4252"
},
"source": [
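"loss = nn.CrossEntropyLoss()\n",
"# nn.CrossEntropyLoss treats dim 1 of the input as the class dimension. With input (2, 3, 5)\n",
"# that means 3 classes and an expected target of shape (2, 5), hence the error above.\n",
"input = torch.randn(2, 3, 5, requires_grad=True)\n",
"print(input)\n",
"# Merge the first two dims: (2, 3, 5) -> (6, 5), i.e. 6 samples with 5 class scores each\n",
"input = input.flatten(end_dim=1)\n",
"print(input)\n",
"# Flatten the matching targets: (2, 3) -> (6,)\n",
"target = torch.empty(2, 3, dtype=torch.long).random_(5).flatten()\n",
"print(target.size())\n",
"output = loss(input, target)\n",
"print(output)"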
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor([[[ 0.3482, 1.1633, -0.4005, 1.1513, -2.1177],\n",
" [-0.0121, -1.4946, 0.4862, 0.7363, -0.4808],\n",
" [ 0.0175, 0.2177, -0.6881, 0.7898, 1.0728]],\n",
"\n",
" [[ 0.6282, -1.9214, 0.0433, -0.7178, -0.0831],\n",
" [ 0.0064, 0.0601, -0.9912, -0.3535, 0.4251],\n",
" [ 0.2370, -1.0435, 0.2936, 0.2425, 0.0972]]], requires_grad=True)\n",
"tensor([[ 0.3482, 1.1633, -0.4005, 1.1513, -2.1177],\n",
" [-0.0121, -1.4946, 0.4862, 0.7363, -0.4808],\n",
" [ 0.0175, 0.2177, -0.6881, 0.7898, 1.0728],\n",
" [ 0.6282, -1.9214, 0.0433, -0.7178, -0.0831],\n",
" [ 0.0064, 0.0601, -0.9912, -0.3535, 0.4251],\n",
" [ 0.2370, -1.0435, 0.2936, 0.2425, 0.0972]], grad_fn=<ViewBackward>)\n",
"torch.Size([6])\n",
"tensor(1.4924, grad_fn=<NllLossBackward>)\n"
],
"name": "stdout"
}
]
},
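{
"cell_type": "markdown",
"metadata": {},
"source": [
"An alternative fix keeps the (batch, position) structure instead of flattening. `nn.CrossEntropyLoss` also accepts input of shape (N, C, d1) with target of shape (N, d1), so we can move the 5 class scores to dimension 1 with `permute`. This sketch assumes, as the flattening solution does, that the last dimension of the input holds the 5 class scores:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"torch.manual_seed(0)\n",
"loss = nn.CrossEntropyLoss()\n",
"logits = torch.randn(2, 3, 5, requires_grad=True)  # (batch, position, class)\n",
"target = torch.empty(2, 3, dtype=torch.long).random_(5)  # (batch, position)\n",
"\n",
"# Move the class dimension to index 1: (2, 3, 5) -> (2, 5, 3), matching target (2, 3)\n",
"out_permuted = loss(logits.permute(0, 2, 1), target)\n",
"\n",
"# With the default reduction='mean', both versions average over the same 6 terms\n",
"out_flat = loss(logits.flatten(end_dim=1), target.flatten())\n",
"print(torch.allclose(out_permuted, out_flat))"
],
"execution_count": null,
"outputs": []
},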
{
"cell_type": "markdown",
"metadata": {
"id": "6vdCcbeWj83k"
},
"source": [
"##### Exercises:\n",
"\n",
"\n",
"1. Sample 28 points in $\\mathbb{R}^2$, each coordinate drawn uniformly from $[0, 1)$.\n",
"2. Calculate distances between all pairs of points and store the result in a $28 \\times 28$ matrix.\n",
"3. Calculate the average distance over all pairs of points.\n",
"4. Calculate the distance of each point to the origin.\n",
"5. Calculate the variance of the distances to the origin.\n",
"6. Write a function that performs steps 1 to 5 and returns the three results. Run it on 1000 uniformly sampled points in $\\mathbb{R}^{10}, \\mathbb{R}^{50}, \\mathbb{R}^{100}, \\mathbb{R}^{500}, \\mathbb{R}^{1000}$. Plot the averages against the dimension with `matplotlib`, using a log scale for the dimension axis. What do you observe?\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rh-XDyt3XOwZ"
},
"source": [
"##### Solution"
]
},
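{
"cell_type": "markdown",
"metadata": {},
"source": [
"The pairwise-distance loop can also be replaced by NumPy broadcasting, shown here for the 28 points in R^2 from steps 1 to 5. This sketch averages over distinct pairs only, excluding the zero diagonal, whereas `np.mean` over the full matrix includes those n zeros. The intermediate `diffs` array has n * n * d entries, so for 1000 points in R^1000 it would need about 8 GB of float64 memory; at that scale the row-by-row loop in the solution is the more memory-friendly choice."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"pts = rng.random((28, 2))  # 28 points in R^2, coordinates uniform in [0, 1)\n",
"\n",
"# (28, 1, 2) - (1, 28, 2) broadcasts to (28, 28, 2): all pairwise difference vectors\n",
"diffs = pts[:, None, :] - pts[None, :, :]\n",
"pair_dists = np.linalg.norm(diffs, axis=-1)  # (28, 28) distance matrix\n",
"\n",
"# Average over distinct pairs, excluding the zero diagonal\n",
"n = len(pts)\n",
"avg_pair = pair_dists.sum() / (n * (n - 1))\n",
"\n",
"dists_from_origin = np.linalg.norm(pts, axis=1)\n",
"print(pair_dists.shape, avg_pair, dists_from_origin.mean(), dists_from_origin.var())"
],
"execution_count": null,
"outputs": []
},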
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 279
},
"id": "ABvS0tSFoeRL",
"outputId": "107b1927-5268-4a6a-df85-366f17e6408f"
},
"source": [
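"def curse(n_samples, dim):\n",
"    # Step 1: sample points uniformly in the unit hypercube [0, 1)^dim\n",
"    pts = np.random.rand(n_samples, dim)\n",
"    # Step 2: row i holds the distances from point i to every point\n",
"    pair_dists = np.zeros((n_samples, n_samples))\n",
"    for i, p in enumerate(pts):\n",
"        pair_dists[i,:] = np.linalg.norm(pts - p, axis=1)\n",
"    # Step 4: distance of each point to the origin\n",
"    dists_from_origin = np.linalg.norm(pts, axis=1)\n",
"    # Steps 3 and 5: average pairwise distance (diagonal zeros included), plus the\n",
"    # average and the variance of the distances to the origin\n",
"    return np.mean(pair_dists), np.mean(dists_from_origin), np.var(dists_from_origin)\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"pairwise_dists = []\n",
"dists_from_origin = []\n",
"origin_vars = []\n",
"dims = np.array([10, 50, 100, 500, 1000])\n",
"for dim in dims:\n",
"    pair, mag, var = curse(1000, dim)\n",
"    pairwise_dists.append(pair)\n",
"    dists_from_origin.append(mag)\n",
"    origin_vars.append(var)\n",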
"plt.figure(1, figsize=(12,4))\n",
"plt.subplot(121)\n",
"plt.plot(np.log(dims), pairwise_dists)\n",
"plt.xlabel(\"Dimensionality (log)\")\n",
"plt.ylabel(\"Average distance between random points\")\n",
"plt.subplot(122)\n",
"plt.plot(np.log(dims), dists_from_origin)\n",
"plt.xlabel(\"Dimensionality (log)\")\n",
"plt.ylabel(\"Average distance from origin\")\n",
"plt.show()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAs0AAAEGCAYAAACeiKhrAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd5hU5fnG8e9D731BegcVWASWIliwxF4RiUiUomKJJYnGEns0sfxMjA0FlWLDCkGNYgXUgNJ7Z5fe27IsLNue3x8z6ErY3WF3Z8+W+3Ndc82cMzPn3CQ6PJ7zvs9r7o6IiIiIiGSvTNABRERERESKOhXNIiIiIiK5UNEsIiIiIpILFc0iIiIiIrlQ0SwiIiIikotyQQeIRL169bxFixZBxxAROWZz5szZ6e4xQecoTPrNFpHiLLvf7WJRNLdo0YLZs2cHHUNE5JiZ2bqgMxQ2/WaLSHGW3e+2hmeIiIiIiORCRbOIiIiISC5UNIuIiIiI5EJFs4iIiIhILlQ0i4iIiIjkItei2cyuNLPq4dcPmNkEM+sa/WgiIiIiIkVDJFeaH3T3JDM7BTgbeB14ObqxRERERESKjkiK5ozw84XAKHf/D1AhepFERIqWpJQ0Hvl4CbuTU4OOIiIiuXB3npq8nBVbkwr0uJEUzZvMbCTwW+AzM6sY4fdEREqEpyevYNyMtazdlRx0FBERycXHCzbz8tQ1fL9qR4EeN5LidwDwBXCuu+8F6gB/LtAUIiJF1E/xu3jzx3UM69OSrs1qBx1HRERysDs5lUc/WUrnprUY2qdlgR47kqJ5pLtPcPdVAO6+BbimQFOIiBRBKWkZ3DthEU3rVObOc9oFHUdERHLx10+WkJSSxtNXxFK2jBXosSMpmjtk3TCzskC3Ak0hIlIE/evrVSTsTObJfrFUqVAu6DgiIpKDb5dv49/zN3NL3za0P656gR8/26LZzO4zsyQg1sz2hR9JwHZgUoEnEREpQhZtTOTV7+P5bVxT+rSpF3QcERHJQVJKGg9MXEzb+tW45YzWUTlHtkWzuz/h7tWB/3P3GuFHdXev6+73RSWNiEgRkJaRyd0fLaRu1Qr85cITgo4jIiK5eHryCrbsS+Gp/rFULFc2KufI9X6ju99nZo2B5lk/7+7fRSWRiEjARn0Xz7It+xh1TTdqVi4fdJx8M7PRwEXAdnfvGN73HtA+/JFawF53P+ko310LJBFqP5ru7nGFElpEJEIzE3YXyoTtXItmM3sSuApYyi89mx1Q0SwiJc7q7Uk89/UqLoxtyDkdjgs6TkEZC7wIvHF4h7v/9vBrM/sHkJjD989w951RSycikkcpaRnc+9FCmtSuzF3nRnfCdiQzWy4H2rv7oagmEREJWEamc/eHC6lSsSyPXNwh9y8UE+7+nZm1ONp7ZmaEWoueWZiZREQKwvPfrCJ+ZzJvXtcj6hO2I+meEQ8U//uTIiK5eHPGWuau38tDF51ITPWKQccpLKcC2w63FT0KB740szlmNjy7g5jZcDObbWazd+wo2AUFRESOZsnmREZ+F8+V3ZpwatuYqJ8vkpL8ADDfzL4Bfr7a7O63Ry2ViEgh27D7AE9/sYLT28VweZfGQccpTAOB8Tm8f4q7bzKz+sBXZrb8aHNa3H0UMAogLi7OoxNVRCQkPSOTuz9cSO0qFXjgwhML5ZyRFM0fhx8iIiWSu/OXiYsw4O/9OhEasVDymVk5oB859N53903h5+1mNhHogea0iEjAXv0+gSWb9/HyoK7UrFI4AyIi6Z4xLi8Hzma29v8BFwOpwBpgaHhpbhGRwHw4ZyPfr9rJY5d2oHGtykHHKUxnA8vdfePR3jSzqkAZd08Kvz4H+GthBhQROVL8jv386+uVnNuhAed3alho581pcZP3w8+LzGzhkY8Ijj0WOO+IfV8BHd09FlgJqN+ziARqe1IKj326lO4tajOoZ/Og40SFmY0HZg
DtzWyjmV0XfusqjhiaYWaNzOyz8GYD4AczWwDMBP7j7pMLK7eIyJEyM517JyyiQrkyPHZpx0I9d05Xmu8IP1+UlwMfbba2u3+ZZfNHoH9eji0iUlAenrSElPRMnrwiljJlSuawDHcfmM3+IUfZtxm4IPw6Hugc1XAiIsfgnZnrmZmwm6eviKV+jUqFeu6cVgTcEn5eB6QAncKPg+F9+TUM+Dy7NzUTW0Si7fNFW/h88Vb+cHZbWsdUCzqOiIjkYEviQZ78fDl92tTlyrgmhX7+XFvOmdkAQrflriTUy/MnM8vXFWIzux9IB97O7jPuPsrd49w9LiYm+m1ERKR0STyQxoOTltChUQ1uOLVV0HFERCQH7s79ExeTnpnJE5fHBjJhO5LuGfcD3d19O4CZxQBfAx/m5YRmNoTQkI+z3F1tiUQkEI//Zyl7DqQydmh3ypeNpGW9iIgE5eMFm/l2+XYeuPAEmtWtEkiGSIrmMocL5rBdRLYoyv8ws/OAu4HT3f1AXo4hIpJf363cwQdzNvL7M1rTsXHNoOOIiEgOdien8ugnS+nctBZD+7QMLEckRfNkM/uCX2ZY/xb4LIfPAz/P1u4L1DOzjcDDhLplVCTUIB/gR3e/KQ+5RUTyJPlQOvdNWESrmKrcdmbboOOIiEgu/vrJEpJS0nj6iljKBjhhO5I+zX82s37AKeFdo9x9YgTfO9ps7dePMZ+ISIH6vy9WsDnxIB/ceDKVypcNOo6IiORgyvLt/Hv+Zu44qy3tj6seaJZIrjQDTAcygExgVvTiiIhEz5x1uxk3Yy3X9mpOXIs6QccREZEcJKWkcf/ERbStX41bzmgddJyIumdcT6h7xuWE+ir/aGbDoh1MRKQgpaRlcPeHC2lUszJ/Pu/4oOOIiEgunp68gi37UniqfywVywV/ZzCSK81/Brq4+y4AM6tL6Mrz6GgGExEpSC9NWc2aHcmMG9aDahUjvckmIiJBmJmwmzd/XMfQPi3o2qx20HGAyLpg7AKSsmwnhfeJiBQLSzfv4+Wpa7iiaxNOb6e+7yIiRVlKWgb3frSQJrUrc9c57YOO87NILresJrSgySTAgUuBhWb2JwB3/2cU84mI5Et6RiZ3f7SAWlXK8+BFJwQdR0REcvH8N6uI35nMm9f1oGoRujMYSZI14cdhk8LPwU5hFBGJwGs/JLB40z5GDOpKrSoVgo4jIiI5WLI5kZHfxdO/WxNObVu07gxG0nLu0cIIIiJS0OJ37OfZr1ZybocGnN/xuKDjiIhIDtIzMrnno4XUrlKBBy4sencGi841bxGRApSZ6dw7YREVy5XhsUs7El5QSUREiqjDdwZfLqJ3BlU0i0iJ9PbM9cxM2M3T/WOpX6NS0HEKhJm1I9TRqDlZfr/d/czAQomIFIBf3Rns1DDoOEelollESpxNew/y5GfLOKVNPa7s1iToOAXpA+AV4FVCC06JiBR7h+8MVgjfGSyqci2azawlcBvQgl9f2bgkerFERPLG3bl/4iIyHZ7o16mkDctId/eXgw4hIlKQxs8K3Rl86opORfrOYCRXmv8NvA58QmgZbRGRImvS/M1MXbGDhy46kaZ1qgQdp6B9Yma3ABOBQ4d3uvvu4CKJiOTdlsSDPPHZcnq3rsuAuKZBx8lRJEVzirs/H/UkIiL5tHP/IR79ZAldmtVicO8WQceJhsHh5z9n2edAqwCyiIjki7vzwMTFpGdm8mS/2CJ/ZzCSovk5M3sY+JJfX9mYG7VUIiJ58OgnS0k+lMHTV8RStkzR/vHNC3dvGXQGEZGC8snCLXyzfDsPXHgCzeoW/TuDkRTNnYBrgDP5ZXiGh7dFRIqEr5Zu45MFm7nzN+1o26Bkrb1kZme6+7dm1u9o77v7hMLOJCKSH7uTU3nk4yV0blqLoX2Kx/WASIrmK4FW7p4a7TAiInmReDCNB/69iOOPq86Np7cOOk40nA58C1x8lPccyLFoNrPRwEXAdnfvGN73CHADsCP8sb
+4+2dH+e55wHNAWeA1d38yj38GEZGfPfbpUpJS0orVncFIiubFQC1ge5SziIjkyZOfL2NH0iFevTaOCuXKBB2nwLn7w+HnoXk8xFjgReCNI/Y/6+7PZPclMysLvAT8BtgIzDKzj919aR5ziIgwZfl2Js7bxO1ntaX9ccXnzmAkRXMtYLmZzeLXY5rVck5EAjd99U7Gz9zAjae1IrZJraDjRJWZ/ekouxOBOe4+P7vvuft3ZtYiD6fsAax29/jw+d8FLgVUNItInuw/lM79ExfRtn41fn9G8bozGEnR/HDUU4iI5MHB1AzunbCIFnWr8Iez2wUdpzDEhR+fhLcvAhYCN5nZB+7+9DEe71YzuxaYDdzp7nuOeL8xsCHL9kag59EOZGbDgeEAzZo1O8YYIlJaPD15OVv2pfDhTb2pWK5s0HGOSa73Md19GrAcqB5+LAvvExEJ1D+/WsH63Qd48opYKlcoXj++edQE6Orud7r7nUA3oD5wGjDkGI/1MtAaOAnYAvwjP8HcfZS7x7l7XExMTH4OJSIl1MyE3bwxYx1DeregW/PaQcc5ZrkWzWY2AJhJaELgAOAnM+sf7WAiIjmZv2Evr/+QwKCezejVqm7QcQpLfbIMkwPSgAbufvCI/bly923unuHumYSW5e5xlI9tArKuNtAkvE9E5JikpGVw70cLaVK7Mned0z7oOHkSyfCM+4Hu7r4dwMxigK+BD6MZTEQkO6npmdz94QIa1KjEvecfH3ScwvQ2oQsXk8LbFwPvmFlVjnGcsZk1dPct4c3LCU36PtIsoK2ZtSRULF8FXJ2n5CJSqr3w7SridybzxrAeVK0YSflZ9ESSuszhgjlsFxFcoRYRiZYRU1ezctt+Rg+Jo3ql8kHHKTTu/piZfQ70Ce+6yd1nh18Pyu57ZjYe6AvUM7ONhOaq9DWzkwi1rFsL3Bj+bCNCreUucPd0M7sV+IJQy7nR7r6k4P9kIlKSLdmcyCvT4unfrQmntSu+w7ciKZonm9kXwPjw9m+B/+nlKSJSGFZsTeKlKau59KRGnHl8g6DjFAozq+Hu+8ysDhAffhx+r467787p++4+8Ci7X8/ms5uBC7Jsf4Z+80Ukj9IzMrnno4XUrlKBBy48Ieg4+ZJr0ezufzazK/jlysYod58Y3VgiIv8rI9O5+6OFVK9UnocuOjHoOIXpHUKdMuYQujJ8mIW3WwURSkQkN6/9kMDiTfsYMagrtapUCDpOvkQ0qMTdPwI+OpYDZ7MCVR3gPaAFoduBA47S4khE5KjG/DeBBRv28txVJ1G3WsWg4xQad7/IzAw43d3XB51HRCQSCTuTefarlZzboQHndzwu6Dj5lu3YZDNLMrN92T0iOPZY4Lwj9t0LfOPubYFvwtsiIrlatyuZZ75cwdkn1OeSzo2CjlPo3N2B/wSdQ0QkEpmZzr0fLaRCuTI8dmlHQv/dX7xle6XZ3asDmNljhHp4vknoVuAgoGFuB85mBapLCU1GARgHTAXuObbIIlLauDv3frSI8mXK8NhlJePHN4/mmll3d58VdBARkZyMn7WenxJ289QVnahfo1LQcQpEJMMzLnH3zlm2XzazBcBDeThfgywtjrYC2c7i0epSInLYe7M2MCN+F3+/vBMNa1YOOk6QegKDzGwdkEx4TLO7xwYbS0TkF1sSD/LEZ8vp3bouA+Ka5v6FYiKSojnZzAYB7xKacDKQ0I91vri7m5nn8P4oYBRAXFxctp8TkZJta2IKf/vPMnq1qsNV3UvOj28enRt0ABGRnLg7D0xcTHpmJk/2iy1RdwYj6bd8NaGVALeFH1eS9+b228ysIYQa6wPbc/m8iJRi7s4D/15Makbox7dMmZLz45sX7r4OqEVoUZOLgVrhfSIiRcInC7fwzfLt3HVOe5rVrRJ0nAKVa9Hs7mvd/VJ3r+fuMe5+mbuvzeP5PgYGh18PBibl8FkRKeU+XbiFr5dt485z2tGiXtWg4wTOzO4gtCpg/fDjLT
O7LdhUIiIhu5NTeeTjJXRuWouhfVoGHafA5To8I7xs9g2E2sT9/Hl3H5bL9462AtWTwPtmdh2wjtAVbBGR//Hzj2+TmgwrgT++eXQd0NPdkwHM7ClgBvBCoKlERIDHPl3KvoNpPHVFJ8qWwDuDkYxpngR8D3wNZER64GxWoAI4K9JjiEjp9dinS0k8mMbbN/SkXNlIRpKVCsavf4czwvtERAI1ZcV2Js7bxO1nteX442oEHScqIimaq7i72sKJSKGZsrzk//jm0RjgJzM7vCrrZWSzHLaISGHZfyid+ycsom39avz+jNZBx4maSC7ffGpmF0Q9iYgIkJSSxv0TS/6Pb164+z+BocDu8GOou/8r2FQiUto9PXk5W/al8OQVsVQsVzboOFETyZXmO4C/mNkhII1f+oLq8o+IFLinJ69gy74UPrq5d4n+8c0rd58LzA06h4gIwKy1u3ljxjqG9mlBt+a1g44TVbkWzYdXBhQRibaf4nfx5o/rGNanJV2blewfXxGR4i4lLYN7PlpI41qVueuc9kHHibpIrjRjZrWBtsDP6yC6+3fRCiUipU9KWgb3TlhE0zqVuevcdkHHERGRXLzw7SridyTzxrAeVK0YUUlZrEXScu56QkM0mgDzgV6EWhydGd1oIlKa/OvrVSTsTObt63tSpULJ//HNDzOrwa9bgO4OMI6IlEJLNifyyrR4+ndrwmntYoKOUygimQh4B9AdWOfuZwBdgL1RTSUipcqijYm8+n08v41rSp829YKOU2SZ2Y1mthVYCMwJP2YHm0pESpv0jEzu+WghtatU4IELTwg6TqGJ5HJOirunmBlmVtHdl5tZyR+4IiKFIi0jk7s/WkjdqhX4Syn68c2ju4CO7r4z6CAiUnq99kMCizftY8SgrtSqUiHoOIUmkqJ5o5nVAv4NfGVmewit5icikm8jp61h2ZZ9jLymGzUrlw86TlG3BjgQdAgRKb0Sdibz7FcrObdDA87veFzQcQpVJN0zLg+/fMTMpgA1gclRTSUipcK46Wv5x1cruTC2Ied2KF0/vnl0HzDdzH4CDh3e6e63BxdJREqLzXsPMvyN2VQoV4a/XtoRs9K1IGmORbOZlQWWuPvxAO4+rVBSiUiJlpnpPPXFckZOi+fsExrwTP/OQUcqLkYC3wKLgMyAs4hIKbJsyz6GjJnJgUMZjLy2Gw1qVMr9SyVMjkWzu2eY2Qoza+bu6wsrlIiUXKnpmdz94QL+PX8zg3o249FLOlCubCRzkgUo7+5/OtYvmdlo4CJgu7t3DO/7P+BiIJXQsI+h7v4/k7zNbC2QBGQA6e4el/f4IlIcTV+9kxvfnEPViuX44OaTOf640rm+XSR/U9UGlpjZN2b28eFHtIOJSMmzLyWNoWNn8u/5m/nzue15/LKOKpiPzedmNtzMGppZncOPCL43FjjviH1fEZpUGAusJDT0IztnuPtJKphFSp9J8zcxeMxMGtaqxIRbepfaghkimwj4YNRTiEiJtzUxhSFjZrJ6+36eubIz/bs1CTpScTQw/Jy1wHWgVU5fcvfvzKzFEfu+zLL5I9C/APKJSAnh7rwyLZ6nJi+nV6s6jLwmrtRP1o5kIqDGMYtIvqzalsTg0TNJPJjG6CHdS00j/ILm7i2jdOhhwHvZnRb40swcGOnuo472ITMbDgwHaNasWVRCikjhyMh0Hv1kCW/MWMfFnRvxzJWxVCxXNuhYgdOyWyISVTMTdnP9uFlULF+W9248mY6NawYdqdgys/LAzcBp4V1TCRWyafk45v1AOvB2Nh85xd03mVl9Qm1Hl7v7d0d+KFxMjwKIi4vzvOYRkWClpGVwx7vz+GLJNoaf1op7zzueMmVKV5eM7KhoFpGo+WzRFv7w3nya1K7MuKE9aFqnStCRiruXgfLAiPD2NeF91+flYGY2hNAEwbPc/aiFrrtvCj9vN7OJQA/gf4pmESn+9iSnct24WczbsJeHLz6RoX2idXOreFLRLCJRMf
qHBB77z1K6NqvNa9fGUbtq6Vk1Koq6u3vW/nzfmtmCvBzIzM4D7gZOd/ejLphiZlWBMu6eFH59DvDXvJxPRIq2DbsPMHj0TDbuPciIq7tyfqeGQUcqcnItms2sD/AI0Dz8eQPc3XOceCIipVNmpvPE58t49fsEzu3QgOeu6kKl8hoLV0AyzKy1u68BMLNWhFrB5cjMxgN9gXpmthF4mNBkwoqEhlwA/OjuN5lZI+A1d78AaABMDL9fDnjH3bW4lUgJs2hjIkPHziItI5O3r+9J9xaRNOUpfSK50vw68EdgDhH8OItI6XUoPYO7PljIJws2c+3JzXn44g6U1Vi4gnQXMMXM4gldwGgODM3tS+4+8Ci7X8/ms5uBC8Kv4wGtPCNSgk1dsZ1b3p5L7SoVeHd4T9rUrx50pCIrkqI50d0/j3oSESnWEg+mceObs/kxfjf3nHc8N53eqtQtsRpN4RVaOwNtgfbh3Svc/VD23xIRyd77szdw34RFtG9QnbFDu1O/FK7ydywiKZqnhFeOmgD8/OPs7nOjlkpEipUtiQcZMnoW8Tv38+xvO3N5F/VgLmjhFVoHuvuzwMKg84hI8eXuPP/Nap79eiWntq3HiEFdqV6pdPdgjkQkRXPP8HPWlaAcOLPg44hIcbNiaxJDxswkKSWdMUN6cErbekFHKsn+a2YvEuqpnHx4py5iiEik0jMyeXDSYsbP3EC/ro156opYymtl1ohEsrjJGYURRESKnxlrdjH8zdlULl+W927sRYdG6sEcZSeFn7N2sNBFDBGJyIHUdG59Zx7fLt/OrWe04c5z2mkY3TGIpHtGA+DvQCN3P9/MTgROdvejTiIRkdLhkwWbufP9BTSrW4WxQ7vTpLZ6MEeLmd3h7s8BD7r7D0HnEZHiZ+f+QwwbO4vFmxL52+UdGdSzedCRip1IrsePBb4AGoW3VwJ/yM9JzeyPZrbEzBab2Xgz08hzkWLkte/juW38PE5qWosPbzpZBXP0He6Q8XygKUSkWErYmUy/EdNZuS2JUdfEqWDOo0jGNNdz9/fN7D4Ad083szy3njOzxsDtwInuftDM3geuIlSci0gRlpnpPP6fZYz+bwIXdDqOfw44ST2YC8cyM1sFNDKzrJMAD/fNjw0ol4gUcfPW7+G6cbMBGH9DL7o0qx1wouIrkqI52czqEho3h5n1AhIL4LyVzSwNqAJszufxRCTKUtIyuPP9Bfxn0RaG9G7BgxedqB7MhcTdB5rZcYTu+l0SdB4RKR6+WrqN28bPpUGNSowd2oOW9aoGHalYi6Ro/hPwMdDazP4LxAD983pCd99kZs8A64GDwJfu/uWRnzOz4cBwgGbNmuX1dCJSABIPpHHDm7OZmbCb+y84getPbanJI4XM3beihUZEJEJv/biOhyYtplPjmrw+pDv1qlUMOlKxF0n3jLlmdjqhZvpGqJl+Wl5PaGa1gUuBlsBe4AMz+527v3XEeUcBowDi4uI8r+cTkfzZvPcgg0fPZO2uZJ676iQuPalx0JFERCQb7s4zX67gpSlrOPP4+rx4dReqVIjkGqnkJpLuGVUIXW1u7u43mFlbM2vv7p/m8ZxnAwnuviN8/AlAb+CtHL8lIoVu2ZZ9DBkzkwOHMhg3rAe9W6sHs4hIUZWansm9ExYyYe4mBvZoymOXdqScejAXmEj+lxwDpAInh7c3AY/n45zrgV5mVsVC93fPApbl43giEgXTV+9kwCszMIwPbj5ZBXMREr6YISLys6SUNK4bN4sJczfxp9+04++Xd1LBXMAi+V+ztbs/DaQBuPsBQsM08sTdfwI+BOYCi8IZRuX1eCJS8CbN38TgMTNpWKsSE27pzfHH1Qg6kgBm1tvMlgLLw9udzWxEwLFEJGDb9qXw25E/Mn3NLp7uH8vtZ7XVvJMoiGSQS6qZVeaX7hmtgUP5Oam7Pww8nJ9jiEjBc3dGfRfPE58vp2fLOoy6No6alcsHHUt+8SxwLqHJ2bj7AjM7Ld
hIIhKk1duTGDx6FnsOpDJ6SHdObxcTdKQSK5Ki+RFgMtDUzN4G+gBDophJRAKQkek89ulSxk5fy4WxDfnngM5ULKcezEWNu2844gpSnvvmi0jxNjNhN9ePm0WFcmV5/8aT6di4ZtCRSrRIumd8aWZzgF6EhmXc4e47o55MRApNSloGf3h3PpOXbOX6U1rylwtOoIx6MBdFG8ysN+BmVh64A80JESmVPlu0hT+8N58mtSszbmgPmtbRVIdoi6R7xlvANOB7d18e/UgiUpj2HkjlhjdmM2vtHh648ASuP7VV0JEkezcBzwGNCU3K/hL4faCJRKTQvf5DAo//Zyldm9XmtWvjqF21QtCRSoVIhme8DpwKvBAezzwP+M7dn4tqMhGJuo17DjBkzCzW7zrAi1d34aLYRkFHkhyE7/INCjqHiAQjM9P5+2fLeO2HBM7t0IDnrupCpfIaRldYcu2e4e5TgL8BDwKvAnHAzVHOJSJRtmRzIv1GTGfbvhTeuK6HCuZiwMzGmVmtLNu1zWx0kJlEpHAcSs/g9nfn8doPCQw+uTkjBnVTwVzIIhme8Q1QFZgBfA90d/ft0Q4mItHz/aod3PzWXKpXKsdHN/emXYPqQUeSyMS6+97DG+6+x8y6BBlIRKIv8UAaw9+czU8Ju7n3/OO58bRWaikXgEiGZywEugEdgURgr5nNcPeDUU0mIlExYe5G7v5wIW3qV2PM0O40rFk56EgSuTJmVtvd9wCYWR0i+x0XkWJq896DDBkzk4SdyTx31UlcelLjoCOVWpEMz/iju58G9AN2EVohcG/O3xKRosbdGTF1NX96fwHdW9Th/ZtOVsFc/PwDmGFmj5nZ48B04OncvmRmo81su5ktzrKvjpl9ZWarws+1s/nu4PBnVpnZ4AL7k4hIrpZt2cflI/7Llr0pjBvaQwVzwHItms3sNjN7j9AEwEuB0cD50Q4mIgUnI9N5aNISnp68gks6N2LssO7UqKRFS4obd38DuALYBmwF+rn7mxF8dSxw3hH77gW+cfe2wDfh7V8JX8l+GOgJ9AAezq64FpGCNX31Tga8MgPDeP+mk+ndpl7QkUq9SG7rVQT+Ccxx9/Qo5xGRApaSlsHt4+fx5dJt3HhaKwkl7U4AACAASURBVO4573j1YC7elgN7CP9+m1kzd1+f0xfc/Tsza3HE7kuBvuHX44CpwD1HfOZc4Ct33x0+11eEiu/xeU4vIrmaNH8Td32wgJb1qjJ2aA8a1dJdwaIgkqK5s7s/k3WHmb3p7tdEKZOIFJA9yalcN24W8zbs5eGLT2Ron5ZBR5J8MLPbCF353UZoJUADHIjNw+EauPuW8OutQIOjfKYxsCHL9sbwvqNlGw4MB2jWrFke4oiIu/PKtHiemrycni3rMOraOGpW1l3BoiKSorlD1g0zK0doYqCIFGEbdh9g8JiZbNxzkBFXd+X8Tg2DjiT5dwfQ3t13FeRB3d3NzPN5jFHAKIC4uLh8HUukNMrIdB79ZAlvzFjHRbEN+ceAzlQsp5ZyRUm2Y5rN7D4zSwJizWyfmSWFt7cBkwotoYgcs8WbErl8xHR27U/lret6qmAuOTYQ6mJUELaZWUOA8PPRWoluAppm2W4S3iciBSglLYNb3p7DGzPWccOpLXn+qi4qmIugbK80u/sTwBNm9oS731eImUQkH6at3MEtb82hVpUKvDu8J23qqwdzCRIPTDWz/wCHDu9093/m4VgfA4OBJ8PPR7sY8gXw9yyT/84B9PeBSAHKOozuoYtOZNgpGkZXVEUyPON+M/sd0NLdHzOzpkBDd58Z5Wwicow+mL2B+yYsom2D6owd2p0GNSoFHUkK1vrwo0L4EREzG09o0l89M9tIaFz0k8D7ZnYdsA4YEP5sHHCTu1/v7rvN7DFgVvhQfz08KVBE8m/D7gMMHj2TjXsP8tLVXblAdwWLtEiK5peATOBM4DFgf3hf9yjmEpFj4O68NGU1z3y5kj5t6vLK77pRXS
3lShx3fzSP3xuYzVtnHeWzs4Hrs2yPJtRqVEQK0KKNiQwdO4u0jEzeuq4nPVrWCTqS5CKSormnu3c1s3nw87KtEV/hEJHoSs/I5KGPl/DOT+u5vEtjnroilgrlcm3BLsWQmcUAdxOaoP3zbQR3PzOwUCJyzKau2M4tb8+ltobRFSuR/M2aZmZlCbU1OvyjnRnVVCISkYOpGdz01hze+Wk9N/dtzT8HdFbBXLK9TahPc0vgUWAtvwydEJFi4P3ZG7hu3Gxa1K3KxFt6q2AuRiK50vw8MBFoYGZ/A/oDD0Q1lYjkatf+Q1w3bjYLNu7lr5d24NqTWwQdSaKvrru/bmZ3uPs0YJqZqWgWKQbcnee/Wc2zX6/k1Lb1GDGoq4bRFTO5Fs3u/raZzeGXsW+Xufuy6MYSkZys25XMkDGz2Lz3IC8P6sZ5HY8LOpIUjrTw8xYzuxDYDGggpEgRl56RyYOTFjN+5gb6dQ0NoytfVncFi5tIrjQDVAEOD9HQWo4iAVq4cS/Dxs4iPdN554aedGuumqkUedzMagJ3Ai8ANYA/BBtJRHJyIDWdW9+Zx7fLt3PrGW2485x2mFnQsSQPci2azewh4ErgI0JLto4xsw/c/fFohxORX5uyYju/f3sudapWYOzQHrSpXy3oSFK49rh7IqEFTs4AMLM+wUYSkezs3H+IYWNnsXhTIn+7vCODejYPOpLkQyRXmgcBnd09BcDMngTmAyqaRQrR+7M2cN/ERRx/XHXGDO1O/erqwVwKvQB0jWCfiAQsYWcyg0fPZHtSCqOuiePsExsEHUnyKZKieTOh1kYp4e2KaBlVkUJz5OSRl3/XjWoVIx1ZJSWBmZ0M9AZizOxPWd6qQWjonIgUIfPW7+G6cbMBGH9DL7o0q53LN6Q4yPZvXjN7gdAY5kRgiZl9Fd7+DaDVAEUKgSaPSFgFoBqh3+ys/an2EepoJCJFxFdLt3Hb+Lk0qFGJsUN70LJe1aAjSQHJ6XLV7PDzHEIt5w6bmt+Tmlkt4DWgI6FCfJi7z8jvcUVKEk0ekcOytJcb6+7rAMysDFDN3fcFm05EDnvrx3U8NGkxnRrX5PUh3alXrWLQkaQAZVs0u/u4KJ73OWCyu/cPry5YJYrnEil2du4/xHVjZ7FoUyKPX9aR3/XS5BEB4AkzuwnIILSoSQ0ze87d/y/gXCKlmrvzzJcreGnKGs48vj4vXt2FKhU0jK6kKfT7vOF2SacBrwO4e6q77y3sHCJF1dqdyVzx8nRWbEti5DVxKpglqxPDV5YvAz4ntDLgNcFGEindUtMzufODBbw0ZQ0DezRl1DXdVDCXUEH8v9oS2EGodV1nQsM/7nD35KwfMrPhwHCAZs2aFXpIkSDM37CX68bOItOdd27oRVdNHpFfK29m5QkVzS+6e5qZedChREqrpJQ0bnl7Lt+v2smfftOO285so2F0JVjEV5rNrKCGUJQj1B7pZXfvAiQD9x75IXcf5e5x7h4XExNTQKcWKbq+WbaNgaN+pErFsnx0c28VzHI0I4G1QFXgOzNrTmgyoIgUsm37Uhgw8kemr9nF0/1juf2stiqYS7hci2Yz621mS4Hl4e3OZjYiH+fcCGx095/C2x+iHqNSyo2fuZ4b3phNm/rVmHBzH1rFaNES+V/u/ry7N3b3CzxkHeFFTkSk8KzalkS/EdNZtyuZ0UO6MyCuadCRpBBEMjzjWeBc4GMAd19gZqfl9YTuvtXMNphZe3dfAZwFLM3r8USKM3fn2a9W8vy3q+nbPoaXru5KVfVgliOY2e/c/a0jejRn9c9CDSRSis1M2M3142ZRoVxZ3r/xZDo2rhl0JCkkEf3t7O4bjrjlkJHP894GvB3unBEPDM3n8USKnbSMTP4yYREfzNnIld2a8Pd+ndSDWbJzuNFr9Rw/JSJR9dmiLfzhvfk0qV2ZcUN70LSOmn+VJpEUzRvMrDfg4QkodwDL8nNSd58PxOXnGCLFWfKhdG55ey7TVu7g9r
Pa8sezNRZOsufuI8PPjwadRaS0ev2HBB7/z1K6NqvNa9fGUbtqhaAjSSGLpGi+iVBf5caEls/+Evh9NEOJlGQ7kg4xbOwslmxO5Il+nRjYQ91hJGdm9nxO77v77YWVRaS0ycx0/v7ZMl77IYFzOzTguau6UKm8Vq8vjXItmt19JzCoELKIlHjxO/YzeMxMdial8uq1cZx1QoOgI0nxMCf83Ac4EXgvvH0lmhMiEjWH0jO48/0FfLpwC4NPbs5DF3egbBndFSytci2azWwcoT7Ke8PbtYF/uPuwaIcTKUnmrt/DdWNnYWaMH96Lk5rWCjqSFBOHV2g1s5uBU9w9Pbz9CvB9kNlESqrEA2kMf3M2PyXs5t7zj+fG01ppGF0pF8mso9isK/a5+x6gS/QiiZQ8Xy3dxtWv/kiNyuWZcHNvFcySV7WBGlm2q4X35YmZtTez+Vke+8zsD0d8pq+ZJWb5zEN5PZ9IcbF570GuHDmduev38NxVJ3HT6a1VMEtEY5rLmFntcLGMmdWJ8HsiArz14zoemrSYTo1r8vqQ7tSrVjHoSFJ8PQnMM7MpgAGnAY/k9WDhtp8nAZhZWULzViYe5aPfu/tFeT2PSHGybMs+hoyZyYFDGYwb2oPebeoFHUmKiEiK338AM8zsA0I/0v2Bv0U1lUgJ4O7848uVvDhlNWceX58Xr+5ClQr6703JO3cfY2afAz3Du+5x960FdPizgDXhBVNESqXpq3dy45tzqFqxHO/fdDInNKyR+5ek1IhkIuAbZjaHX1ad6ufumngikoO0jEzu/WgRH83dyMAeTXns0o6UUw9mKQDhInlSFA59FTA+m/dONrMFwGbgLndfcuQHzGw4MBygWTN1hJHiZ9L8Tdz1wQJa1qvK2KE9aFSrctCRpIiJ9LLXcmDP4c+bWTN3Xx+1VCLF2P5D6dz81hy+X7WTP57djtvPaqOxcFKkhReaugS47yhvzwWau/t+M7sA+DfQ9sgPufsoYBRAXFycRzGuSIFyd16ZFs9Tk5fTs2UdRl0bR83K5YOOJUVQJN0zbgMeBrYRWgnQAAdioxtNpPjZvi+FoWNnsXxrEk9fEcuA7k2DjiQSifOBue6+7cg33H1fltefmdkIM6sXbkcqUqxlZDqPfrKEN2as46LYhvxjQGcqllMPZjm6SK403wG0d/dd0Q4jUpyt3r6fwaNnsjs5ldcGx3FG+/pBR5ISyMxOAdqGxzfHANXcPSGfhx1INkMzzOw4YJu7u5n1INR1SX8fSLGXkpbBHe/O44sl27jh1Jbcd/4JlFEPZslBRMtoA4nRDiJSnM1eu5vr35hNuTLGezf2IraJWspJwTOzh4E4oD0wBigPvEVo0ZO8HrMq8Bvgxiz7bgJw91cITf6+2czSgYPAVe6u4RdSrO1JTuW6cbOYt2EvD110IsNOaRl0JCkGIima44GpZvYf4NDhne7+z6ilEilGJi/eyh3vzqNRrcqMG9qDZnWrBB1JSq7LCfXJnwvg7pvNrHp+DujuyUDdI/a9kuX1i8CL+TmHSFGyYfcBBo+eyca9B3np6q5c0Klh0JGkmIikaF4fflQIP0Qk7I0Za3n44yV0blKL1wfHUVc9mCW6UsPDJBx+vkosIhFatDGRoWNnkZaRyVvX9aRHyzpBR5JiJJKWc48WRhCR4iQz03n6ixW8Mm0NZ5/QgBcGdqFyBU0ekah738xGArXM7AZgGPBqwJlEioWpK7Zzy9tzqV2lAu8O70mb+vm6SSOlUCTdM2KAu4EOQKXD+939zCjmEimyUtMzueejhUyct4mrezbjr5d0UA9mKRTu/oyZ/QbYR2hc80Pu/lXAsUSKvPdnb+C+CYto16A6Y4d2p0GNSrl/SeQIkQzPeBt4D7gIuAkYDOyIZiiRoiopJY2b35rLD6t38udz23NL39bqwSyFKlwkq1AWiYC78/w3q3n265Wc2rYeIwZ1pXol9WCWvImkaK7r7q+b2R3uPg2YZmazoh1MpK
jZti+FIWNmsWpbEv/XP5Yr49SDWQqXmSUR6pOfVSIwG7jT3eMLP5VI0ZSekcmDkxYzfuYG+nVtzJP9YqlQTncFJe8iKZrTws9bzOxCQsuoauS8lCqrtiUxZMws9h5I5fUh3Tm9XUzQkaR0+hewEXiH0EJTVwGtCXXTGA30DSyZSBFyIDWdW9+Zx7fLt/P7M1pz1zntdVdQ8i2SovlxM6sJ3Am8ANQA/hDVVCJFyMyE3Vw/bhYVypXlvRtPpmPjmkFHktLrEnfvnGV7lJnNd/d7zOwvgaUSKUJ27j/EsLGzWLwpkccv68jvejUPOpKUEJEUzXvcPZHQLcAzAMwsz430RYqTzxdt4Y735tOkdqgHc9M66sEsgTpgZgOAD8Pb/YGU8GstOCKlXsLOZAaPnsn2pBRGXhPHb05sEHQkKUEiGdzzQoT7REqUMf9N4JZ35tKxUQ0+uqm3CmYpCgYB1wDbgW3h178zs8rArUEGEwnavPV7uOLl6ew/lM47N/RSwSwFLtsrzWZ2MtAbiDGzP2V5qwaghrRSYmVmOk9NXs7I7+I558QGPD+wC5XK6x95CV54ot/F2bz9Q2FmESlKvlq6jdvGz6V+9UqMG9aDlvW07o8UvJyGZ1QAqoU/k7UD+D5CtwRFSpxD6Rn8+YOFfLxgM9f0as4jl3SgbBlNHpGiwcwqAdfxv33zhwUWSiRgb/24jocmLaZj45q8Prg7MdW1MqtER7ZFc5b2cmPdfR2AmZUBqrn7vsIKKFJY9qWkceMbc5gRv4u7z2vPzaerB7MUOW8Cy4Fzgb8SGq6xLNBEIgFxd575cgUvTVnDGe1jeGlQV6pUiGSqlkjeRDKm+Qkzq2FmVYHFwFIz+3OUc4kUqi2JBxnwygxmrd3NPwd05pa+bVQwS1HUxt0fBJLdfRxwIdAz4EwihS41PZM731/AS1PWcFX3prx6bZwKZom6SIrmE8NXli8DPgdaEpp8ki9mVtbM5pnZp/k9lkh+rNyWRL8R09m45yBjhnanX9cmQUcSyc7hvvl7zawjUBOoH2AekUKXlJLGsLGzmDBvE388ux1P9OtEubJatESiL5L/LCtvZuUJFc0vunuamRVEa6M7CN1WrFEAxxI5Zu7OF0u2cfeHC6hUvizv3diLDo3Ug1mKtFFmVht4APiY0LyTB4ONJFJ4Vm9P4rbx81m5LYmn+8cyQCuzSiGKpGgeCawFFgDfmVlzQpMB88zMmhC6rfg34E+5fFykQGVkOpMXb+XFKatZtmUf7RpUY/SQ7jSprZZyUnSF55Tsc/c9wHdAq4AjiRSa5Vv38cK3q/ls0RaqVijH64Pj6NteN1mkcOVaNLv788DzWXatM7Mz8nnefwF38+uuHL9iZsOB4QDNmjXL5+lEID0jk48XbOalKatZsyOZVjFV+ceVnbnkpEaU1609KeLcPdPM7gbeDzqLSGFZvCmRF75dxRdLtlGtYjlu6dua605pRZ2qFYKOJqVQTn2af+fubx3Rozmrf+blhGZ2EbDd3eeYWd/sPufuo4BRAHFxcVrpSvLsUHoGE+Zu4uWpa1i/+wDHH1edF6/uwvkdG6qdnBQ3X5vZXcB7QPLhne6+O7hIIgVvwYa9vPDtKr5etp3qlcpx+1ltGdanBbWqqFiW4OR0pflwZ/BsrwbnUR/gEjO7gFCf0Rpm9pa7/66AzyOlXEpaBu/OXM/I7+LZkphC5yY1eeiiOM46ob46Y0hx9dvw8++z7HM0VENKiDnr9vD8N6uYtnIHNSuX587ftGNwnxbUqFQ+6GgiOfZpHhl+frQgT+ju9wH3AYSvNN+lglkK0v5D6bz94zpe/T6enftT6dGiDk/3j+WUNvVULEux5u4tg84gEg0/xe/i+W9X8d/Vu6hTtQL3nHc815zcnGoV1UZOio6chmc8n917AO5+e8HHEcm7xINpjJu+ltH/TWDvgTRObVuPW89oQ89WdYOOJlIgzKwKocnTzdx9uJm1Bdq7u1p3SrHj7s
xYs4vnvlnFTwm7qVetIvdfcAKDejVTz2UpknL6p3JO+LkPcCKhMXQAVwJLC+Lk7j4VmFoQx5LSa9f+Q7z+QwJvzFjH/kPpnH1CA249sw0nNa0VdDSRgjaG0G9z7/D2JuADIM9Fs5mtBZKADCDd3eOOeN+A54ALgAPAEHefm9fzibg7363ayfPfrGLOuj00qFGRhy8+kYE9mlGpfNmg44lkK6fhGeMAzOxm4BR3Tw9vvwJ8XzjxRLK3bV8Ko76L552f1pOSnsEFnRry+75tOLGRWn9LidXa3X9rZgMB3P2AFcyYozPcfWc2750PtA0/egIvo1UIJQ/cnW+Xb+f5b1ezYMNeGtWsxGOXduDKuKYqlqVYiOT+R21CC5Acnp1dLbxPJBAbdh9g5HdreH/WRjLcufSkRtzStw1t6lcLOppItKWaWWVCk/8ws9bAoSif81LgDXd34Eczq2VmDd19S5TPKyVEZqbz1bJtPP/NKpZs3keT2pV5ol8nrujahArl1O5Tio9IiuYngXlmNgUw4DTgkWiGEjmahJ3JjJiymonzNmEG/bs15ebTW9OsrhYlkVLjEWAy0NTM3iY0fG5IPo/pwJfhlV5Hhtt9ZtUY2JBle2N4n4pmyVFmpvP54q288O0qlm9NokXdKjzdP5bLuzRWb3wpliJZ3GSMmX3OL7fj7nH3rdGNJfKLFVuTeGnKaj5duJnyZcvwu17NufH0VjSsWTnoaCKFyt2/NLM5QC9CFzHuyGFYRaROcfdNZlYf+MrMlrv7d8d6EC1IJYdlZDqfLtzMi9+uZtX2/bSKqcqzv+3MxbGNKKdiWYqxiKanhovkSVHOIvIrizYm8uKU0EpQVSuU5YbTWnH9Ka2IqV4x6GgigTCzT4B3gI/dPTm3z0fC3TeFn7eb2USgB6Flug/bBDTNst0kvO/I42hBqlIuPSOTSfNDq67G70ymXYNqPD+wCxd20kJSUjKop4sUObPX7uaFb1czbeUOalQqxx1ntWWoVoISAXiG0AInT5rZLOBd4FN3T8nLwcysKlDG3ZPCr88B/nrExz4GbjWzdwndcUzUeGbJKi0jk4lzN/HilNWs332AExrW4OVBXTm3w3GUUbEsJYiKZikS3J3pa3bxwrer+DF+N3WqVuDu89pzTa/mVNdKUCIAuPs0YJqZlQXOBG4ARhOarJ0XDYCJ4QYc5YB33H2ymd0UPt8rwGeE2s2tJtRybmi+/hBSYhxKz+DDORsZMWUNm/YepFPjmrx6bRxna9VVKaEiKprN7BSgbXh8cwxQzd0TohtNSgN3Z8qK7bzw7Wrmrd9L/eoVefCiExnYo6ma24scRbh7xsWErjh3Bcbl9VjuHg90Psr+V7K8dn69bLeUcilpGbw3awOvTFvDlsQUTmpai8cv60jf9jEqlqVEy7UqMbOHgTigPaHG+uWBtwjN2hbJk8xM54slW3nh29Us3bKPxrUq8/hlHenfrYn6dYpkw8zeJzTmeDLwIjDN3TODTSWlxcHUDN6ZuZ6R09awPekQcc1r83T/WE5pU0/FspQKkVzKuxzoAswFcPfNZlY9qqmkxErPyOTThVt4aUpoVnXLelX5v/6xXKYWRCKReB0Y6O4ZELoLaGYD3V1XgiVqkg+l89aP63j1+3h27k+lV6s6/Ouqkzi5VV0Vy1KqRFI0p7q7h3t4Hp44InJMUtMzmThvIyOmrmHdrgO0b1Bds6pFjpG7f2FmXcIrAg4AEoAJAceSEiopJY03Zqzjte/j2XMgjVPb1uO2M9vSo2WdoKOJBCKSovl9MxsJ1DKzG4BhwKvRjSUlxeGxbyOnrWFzYgqxTWoy6ppunH1CA82qFomQmbUDBoYfO4H3AHP3MwINJiVS4sE0xv53LaP/m0DiwTTOaB/DbWe1pWszLQYspVski5s8Y2a/AfYRGtf8kLt/FfVkUqxt3nuQTxZs5rUfEtgRHvv2xBWxnNZWY99E8mA58D1wkbuvBjCzPwYbSUqavQdSGf
1DAmP+u5akQ+mcfUIDbj+rDbFNagUdTaRIiHRxk68AFcqSrdT0TGav3c3UlTuYumI7K7ftB+CUNvV4YWAXeraso2JZJO/6AVcBU8xsMqH+zPoXSvLF3Vm9fT/TVu5g2sod/JSwm9T0TM7veBy3ntmGDo1qBh1RpEiJpHtGEnDk6k6JwGzgznDLIimFNu45wLSVO5i6YgfTV+8kOTWD8mWNHi3r0L9bE85oX5+2DTRnVCS/3P3fwL/Dc0ouBf4A1Dezl4GJ7v5loAGl2NiXksb01TtDhfKKHWxODK2L07Z+Na7p1ZwBcU1pf5x+t0WOJpIrzf8CNhJautUIXe1oTaibxmigb7TCSdFyKD2DWQl7mLpiO1NX7mD19tDV5Ma1KnNZl8b0bV+f3q3rUrWi+iuLREN46ex3gHfMrDZwJXAPoKJZjioz01m6Zd/PRfKc9XvIyHSqVyxHnzb1uO2sGE5rF0PjWpWDjipS5EVS3Vzi7lmb348ys/nufo+Z/SVawaRo2LD7AFNX7mDaiu1MX7OLA6kZVChbhh4t63BV96b0bR9D65hqGnohUsjcfQ8wKvwQ+dnu5FS+XxUqkr9btYOd+1MB6Ni4Bjed3orT29WnS7NaavMpcowiKZoPmNkA4MPwdn8gJfz6yGEbUsylpGUwM2E3U1fsYOrK7cTvSAagaZ3KXNG1CX3bx3By67parU9EpIhIz8hkwca9TFsRGpu8cFMi7lC7SnlOaxfD6e1iOLVtDDHVKwYdVaRYi6TyGQQ8B4wgVCT/CPwuvJTrrVHMJoVk3a7kn8cmz1izi4NpGVQoV4aeLeswqGdz+raPoVW9qrqaLCJSRGxNTOG78AS+71ftYF9KOmUMujSrzR/Pbsfp7WLo2Lim+uCLFKBIWs7FAxdn8/YPBRtHCkNKWgY/xu9iaviqRMLO0NXk5nWrMCCuCX3b16dXq7pUrqDlrEVEioLU9Exmr9v989Xk5VuTAGhQoyLndTyO09vV55Q29ahZpXzASUVKrki6Z1QCrgM6AJUO73f3YVHMJQUsYWcy08IT+Gas2cWh9EwqlitDr1Z1ufbk5vRtX5+W9bTYo4hIUbF+1wGmhccmT1+zkwPhDkXdW9ThvvOP5/T2MbRvUF13AUUKSSTDM94k1Fj/XOCvhIZrLItmKMm/g6mHryaHCuV1uw4A0LJeVQb2aEbf9jH0alWXSuV1NVlEpCg4mJrBjwm7QhP4Vu4gfuev55Sc3i40p0QdikSCEcm/eW3c/Uozu9Tdx5nZO4RWppIixN1J2JkcnsC3g5/iQ1eTK5X///buPEiu6rrj+PenfUfLtACjfZlR2CLBsGgdLTEFAVMJRVgKcAjlCNvYBscuIE6cEFc5ZccplyEUuGQwITHYxgQwpjBbLITYQUKITRJiiyQLtKAFCUkjzZz88d4MwzBS94ym+7Wmf5+qqe553X3feSPp6M59997TjanjhnHZ9LHMrskxephHk83MykFE8NbGHc1T5ZqKi/TpmdwFvCS9CzhmWD+PJpuVgUI6zXvTx62SjgXeB4YXLyQr1Mf1+3jmrc3NO12s+XAXAONy/bnolNHU1eQ4ZexQjyabmZWJj3bv5anVm1m0KhlNXrc1ydsT0uIiddU5TnbeNitLhXSaF6Sb6P8jcD8wAPhuUaOyNiWjEjt5fOWGT41K9O3ZnWnjhzF/ZrL/5qhh/bIO1czMaFVcZNVGlr63hX2NwYDePZg+YRhXzJnArOoqRgxx3jYrdwfsNEvqBmxPN9F/Ahh3sCeUNBL4L+Bwki3sFkTE9Qfbble1c88+nn5rc3NHee2WT0YlvnhqMpp80hiPSpiZlYvm4iKrNvLEqk1s2rEHgGM+N4j5s8ZRV53jhNFDXFzE7BBzwE5zRDRKuhq4qxPPuQ/4VkQslTQQWCLp0Yh4vRPPcciKCFZv2NE85eKFd7ZQ39BIv17dmTa+ii/XjaeuOsfIoR6VMDMrBw2NwbI1W5tHk5ev3dpcXGTmxL
S4SHUVwwf2yd+YmZWtQqZnPCbp28CvgZ1NByPiw46cMCLWA+vT5x9JegM4CqjYTvOOPft4tVRPbgAADgpJREFUavUmHl/56Tlu1YcP4K+nJQtBascMoXcPjyabmZWDD7bvbu4kP/nmJrbt2ks3weSRg7lqXjV1NTmOc3ERsy6lkE7z+enjFS2OBZ0zVWMMMAV4ro3X5gPzAUaNGnWwpyorEcGqD3Yk28Gt3MiL733I3oagf6/uTJ9QxRVzJlBXk+OowX2zDtXMzGhRXGRVsm9yU3GR4QN7c9rRh1NXk2PGhCoG9+uVcaRmViyFVAQcW4wTSxoA/A9wVURsb+O8C4AFALW1tVGMGEopWTG9qXlrofXbdgMw6YiBXDZ9LHU1OWpHD6VXD89xM7PSKGSNiaTZwG+Bd9JD90TE90oZZ1bWfPgxj6ed5Gfe2sTOtLhI7eihXHvGJOqqc0w6wsVFzCpFIRUB+wF/B4yKiPmSJgI1EfFAR08qqSdJh/mOiLino+2Us4hgxfsfJXOTV25gSYsV0zMmVHHlvBx1NTmOPMyjyWaWmULXmCyOiLMyiK+k9ldcZMSQvvzlCUdRVz2cqeOHMcDFRcwqUiH/8m8DlgDT0u/XAb8BOtRpVvIr+a3AGxHx4460Ua62797Lk29uYlE6mvz+9k9Gk780cxyza3Kc6BXTZlYmKn2NSdM2nk1zk5uKQvXu0Y2p45PiInXVOcZW9fdospkV1GkeHxHnS7oQICI+1sFlj+nAJcArkpalx74TEQ8eRJuZiEj233x8ZXL7bsn/baGhMRjYpwczJ1Yxu3o4s6pzHHGYV0ybWXk70BoTYKqkl4E/At+OiNdKGFqn+mj3Xp5+a3Pz3OSWxUUudnERMzuAQjrN9ZL6ksx3Q9J4YE9HTxgRTwKH7K/s2z7ey+LVG5tHkzd8lPwojj5yEJfPGsfsmuFMGTXYo8lmdsjIs8ZkKTA6InZI+nPgPmBiG22U5eLtpsGNpk7yEhcXMbMOKqTTfB3wEDBS0h0kI8WXFjGmstJUzalpp4uX1myloTEY1KcHM6tzzK5O9uAcPsijyWZ26Mm3xqRlJzoiHpR0k6SqiNjU6n1ls3h7y856Fq9Opso98eZGNn7k4iJmdvAK2T3jEUlLgFNJRoivbJ0su5qtH9fzRIu5yU3VnI49ahBfqRvP7Jock0cOpocTrpkdwgpZYyLpCOCDiAhJJwPdgM0lDDOvhsbg5bVbm3P2y2lxkcH9ejLLxUXMrJMUsnvG74A7gfsjYme+9x+KGhuDV/+4rXmni2VrttIYcFjfnsxKR5OdcM2sC2pzjQkwCiAifgqcC3xF0j5gF3BBRGS+DeiGFsVFFru4iJmVQCHTM/6dpMDJDyS9APwKeCAidhc1siLbsrOeJ97c2Hz7btOOegCOH3EYX5szgbqa4UweOdgJ18y6rELWmETEjcCNpYlo/+r3NbLkvS3NHeU31iezRlxcxMxKpZDpGYuARZK6A3OBvwV+DgwqcmydqrExWL5uW/Pc5Kbbd0P6paPJNTlmTsxRNaB31qGamRlJcZGmTvLTq5PiIj26idoxQ7jm9EnMrnFxETMrnYJ2aE93z/gCyYjzCcDtxQyqs2zesafFaPImPtxZjwTHjxjMN+ZOZHZNjuNHeDTZzKwc7N7bwLNvb27uKL+98ZPiIn8x5SjqqnNMm1Dl4iJmlolC5jTfBZxMsoPGjcCiiGgsdmAHY1d9AxcseIbl67YRAUP796Iu3eVi5sQqhnk02cysrFxz93LuW7auubjIqeOGcfEpo6mryTHOxUXMrAwU8uv6rcCFEdEAIGmGpAsj4orihtZxfXt1Z0xVf+ZOOpzZ6WKQbh5NNjMrWyOG9OWitJN8iouLmFkZKmRO88OSpqQVAc8D3gE+s5dnubn+gilZh2BmZgX6+rzP1EsxMysr++00S6oGLky/NgG/BhQRc0oUm5mZmZlZWTjQSPMKYDFwVkSsBpD0zZJEZWZmZmZWRg
5U0u4cYD2wUNLPJM0jz36eZmZmZmZd0X47zRFxX0RcAEwCFgJXAcMl3SzptFIFaGZmZmaWtQONNAMQETsj4s6I+AIwAngJuKbokZmZmZmZlYm8neaWImJLRCyIiHnFCsjMzMzMrNy0q9NsZmZmZlaJ3Gk2MzMzM8tDEZF1DHlJ2gi8l3UcRVRFshd2pfD1dm2+3k8bHRG5UgVTDpyzuxxfb9fm6/2sNvP2IdFp7uokvRgRtVnHUSq+3q7N12tdXaX9mft6uzZfb+E8PcPMzMzMLA93ms3MzMzM8nCnuTwsyDqAEvP1dm2+XuvqKu3P3Nfbtfl6C+Q5zWZmZmZmeXik2czMzMwsD3eazczMzMzycKc5I5L6SHpe0suSXpP0L1nHVAqSukt6SdIDWcdSbJLelfSKpGWSXsw6nlKQNFjS3ZJWSHpD0tSsYyoWSTXpn23T13ZJV2UdlxVPJebtSsrZUHl52zm7fTm7R7GCs7z2AHMjYoeknsCTkn4fEc9mHViRXQm8AQzKOpASmRMRlbRp/PXAQxFxrqReQL+sAyqWiFgJTIakYwGsA+7NNCgrtkrM25WWs6Gy8rZzdjt4pDkjkdiRftsz/erSqzIljQDOBG7JOhbrfJIOA2YBtwJERH1EbM02qpKZB7wVEV25Cl7Fq7S87ZzdtTlntz9nu9OcofS21zJgA/BoRDyXdUxF9hPgaqAx60BKJIBHJC2RND/rYEpgLLARuC29nXuLpP5ZB1UiFwC/zDoIK74Ky9uVlrOhsvK2c3Y7udOcoYhoiIjJwAjgZEnHZh1TsUg6C9gQEUuyjqWEZkTECcAZwBWSZmUdUJH1AE4Abo6IKcBO4NpsQyq+9Jbm2cBvso7Fiq9S8naF5myorLztnN1O7jSXgfR2yELg9KxjKaLpwNmS3gV+BcyV9ItsQyquiFiXPm4gmTd1crYRFd1aYG2Lkbe7SRJyV3cGsDQiPsg6ECudCsjbFZezoeLytnN2O7nTnBFJOUmD0+d9gc8DK7KNqngi4u8jYkREjCG5LfKHiLg447CKRlJ/SQObngOnAa9mG1VxRcT7wBpJNemhecDrGYZUKhfiqRkVoZLydqXlbKi8vO2c3X7ePSM7RwK3pys4uwF3RURFbOlTIQ4H7pUEyb+zOyPioWxDKomvA3ekt7/eBv4m43iKKv2P9fPA5VnHYiXhvN21VWLeds5uz+ddRtvMzMzM7MA8PcPMzMzMLA93ms3MzMzM8nCn2czMzMwsD3eazczMzMzycKfZzMzMzCwPd5rtoEhqkLRM0muSXpb0LUnd0tdqJd2QUVxPF6HN/5R0bvr8FklHp8+/04G2+kpalJbkHSOpw3uBSnpM0pCOft7MKodztnO2dZw7zXawdkXE5Ig4hmTvwzOAfwaIiBcj4htZBBUR04rc/pciomkT+HYnYOAy4J6IaOiEcP4b+GontGNmXZ9ztnO2dZA7zdZp0rKj84GvKTFb0gMAkq6TdLukxZLek3SOpH+T9IqkhyT1TN93Yvrb/BJJD0s6Mj3+uKQfSnpe0ipJM9Pjx6THlklaLmlienxH+ihJP5L0anqu89Pjs9M275a0QtIdSne0l/RPkl5IP7Og6XhL6WdrJf0A6Jue/w5J35N0VYv3fV/SlW38uC4CfttGu30k3ZbG+pKkOenxfpLukvS6pHslPSepNv3Y/SQVjszMCuac7Zxt7eNOs3WqiHgb6A4Mb+Pl8cBc4GzgF8DCiDgO2AWcmSbh/wDOjYgTgZ8D32/x+R4RcTJwFenICPBl4PqImAzUAmtbnfMcYDLwp8CfAT9qSurAlLSto4FxwPT0+I0RcVJEHAv0Bc46wPVeyycjNxelMX8RQMktzwvSa22mpPLSuIh4t40mr0iajeNIkurtkvqQjEpsiYijge8CJ7aIYQvQW9Kw/cVpZtYW52znbCucy2hbKf0+IvZKeoUkSTeVJ30FGAPUAMcCj6YDBd2B9S0+f0/6uCR9P8AzwD9IGk
Fy6+zNVuecAfwyvaX2gaRFwEnAduD5iFgLIGlZ2uaTwBxJVwP9gKHAa8DvCrnAiHhX0mZJU0hKsr4UEZtbva0K2LqfJmaQ/CdERKyQ9B5QnR6/Pj3+qqTlrT63Afgc0PpcZmYd5ZydcM42wJ1m62SSxgENJAnhT1q9vAcgIhol7Y1Parg3kvxdFPBaREzdT/N70seG9P1ExJ2SngPOBB6UdHlE/KHAcPe0eN4A9EhHCG4CaiNijaTrgD4FttfkFuBS4AiSUYzWdnWgzXz6pO2amRXMORtwzrYCeXqGdRpJOeCnJLfKIt/727ASyEmamrbXU9Ixec45Dng7Im4gmW92fKu3LAbOV7LiOQfMAp4/QJNNiXGTpAHAuQXEvbdpfl/qXuB0ktGRh1u/Ob011z1N9q0tJpk7h6RqYBTJz+Up4Lz0+NHAcU0fSOfvHQG8W0CsZmaAc3aL752zrSAeabaD1Te9TdYT2EeyKvjHHWkoIuqVbA90g6TDSP5+/oTkVtv+nAdcImkv8D7wr61evxeYCrwMBHB1RLwvadJ+Ytgq6WfAq2l7LxQQ+gJguaSlEXFReh0Lga0HWGn9CMntu8daHb8JuDm9HboPuDQi9ki6iWSu3OvACpKfybb0MycCz0bEvgJiNbPK5pztnG0dpI79cmlm+5MuJlkK/FUb8/Wa3nMC8M2IuKTANrsDPSNit6TxJIm7Jk321wP3R8T/dtIlmJlVDOdsK5RHms06UXob7gHg3v0lX4CIWCppoaTuBe772Q9YmN5SFPDViKhPX3vVydfMrP2cs609PNJsZmZmZpaHFwKamZmZmeXhTrOZmZmZWR7uNJuZmZmZ5eFOs5mZmZlZHu40m5mZmZnl8f+OV537w/FfdQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 864x288 with 2 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OnbqVOFYGUQ_"
},
"source": [
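"Points sampled uniformly from the unit hypercube $[0, 1)^d$ tend to lie far from the origin and from each other as the dimension $d$ grows: the expected squared distance to the origin is $d/3$, and the expected squared distance between two independent points is $d/6$."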
]
},
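A quick NumPy experiment makes this concrete. For $x$ uniform on $[0,1)^d$, $\mathbb{E}\lVert x\rVert^2 = d/3$. For two independent points, $\mathbb{E}\lVert x-y\rVert^2 = d/6$. Typical distances therefore grow like $\sqrt{d}$. The sketch below compares empirical averages with these square roots. The sample size and the seed are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the numbers are reproducible

def mean_distances(d, n=2000):
    """Average distance to the origin and between random pairs, for n points uniform on [0, 1)^d."""
    pts = rng.random((n, d))
    to_origin = np.linalg.norm(pts, axis=1).mean()
    # pair point i with point i + n/2 so that every pair consists of two distinct points
    half = n // 2
    between = np.linalg.norm(pts[:half] - pts[half:2 * half], axis=1).mean()
    return to_origin, between

for d in [1, 10, 100, 1000]:
    o, b = mean_distances(d)
    # theory: E||x||^2 = d/3 and E||x - y||^2 = d/6
    print(f"d={d:5d}  to origin: {o:7.3f} (sqrt(d/3)={np.sqrt(d / 3):7.3f})  "
          f"between points: {b:7.3f} (sqrt(d/6)={np.sqrt(d / 6):7.3f})")
```

For small $d$ the square roots are only rough guides. As $d$ grows, the distances concentrate tightly around them.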
{
"cell_type": "markdown",
"metadata": {
"id": "RDlVKKA2ZOSL"
},
"source": [
"### Subsetting"
]
},
{
"cell_type": "code",
"metadata": {
"id": "PNksDqj-ZPmz",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "96f0db3d-2091-4661-aaca-5b4158fea5bd"
},
"source": [
"a = np.random.randn(7,7)\n",
"print(a)\n",
"a[a < 0] = 0\n",
"a"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"[[ 0.39968399 1.49906021 -1.51424546 -1.69419294 -0.51995498 0.71645821\n",
" 0.9006934 ]\n",
" [-1.52950457 0.89882558 1.21495254 0.56172426 1.11672662 -1.14913005\n",
" -0.64828375]\n",
" [ 0.1155437 -0.89443134 1.84528822 -0.10637707 1.46539325 1.21498806\n",
" 0.82827192]\n",
" [-2.14755008 -0.04051744 -1.88366634 0.47326144 0.95436352 -0.82562433\n",
" 0.74747166]\n",
" [-0.57082256 -1.20724777 -1.15340429 0.6954061 0.90065367 -0.3185223\n",
" -0.9356916 ]\n",
" [ 0.86490831 0.42468098 1.81276608 -2.48497615 1.46263681 -0.26064952\n",
" 2.29424169]\n",
" [ 0.62438524 -1.08074034 1.32001285 1.01419146 -0.12433166 1.56485127\n",
" 0.30849677]]\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[0.39968399, 1.49906021, 0. , 0. , 0. ,\n",
" 0.71645821, 0.9006934 ],\n",
" [0. , 0.89882558, 1.21495254, 0.56172426, 1.11672662,\n",
" 0. , 0. ],\n",
" [0.1155437 , 0. , 1.84528822, 0. , 1.46539325,\n",
" 1.21498806, 0.82827192],\n",
" [0. , 0. , 0. , 0.47326144, 0.95436352,\n",
" 0. , 0.74747166],\n",
" [0. , 0. , 0. , 0.6954061 , 0.90065367,\n",
" 0. , 0. ],\n",
" [0.86490831, 0.42468098, 1.81276608, 0. , 1.46263681,\n",
" 0. , 2.29424169],\n",
" [0.62438524, 0. , 1.32001285, 1.01419146, 0. ,\n",
" 1.56485127, 0.30849677]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 38
}
]
},
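`a < 0` is a boolean array with the same shape as `a`. Indexing with it reads or writes exactly the entries where the mask is `True`. The short sketch below compares mask assignment with `np.where` and `np.maximum`. It also uses a one-dimensional mask over rows to keep whole rows.

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.standard_normal((7, 7))

mask = a < 0                    # boolean array, same shape as a
print(mask.dtype, mask.shape)   # bool (7, 7)
print(a[mask].ndim)             # 1: reading through a mask returns a flat array of the selected entries

# three equivalent ways to zero out the negative entries
b = a.copy()
b[mask] = 0                     # in-place assignment through the mask
c = np.where(mask, 0.0, a)      # new array: 0 where mask is True, a elsewhere
e = np.maximum(a, 0)            # elementwise max with 0 (a ReLU)

# a 1-D mask over the rows keeps whole rows, here those whose mean is positive
rows = a[a.mean(axis=1) > 0]
print(rows.shape)
```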
{
"cell_type": "markdown",
"metadata": {
"id": "YUe2SqY8nYbS"
},
"source": [
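"The next cell introduces an important concept: a probability distribution over protein sequence space. The array has shape $(b, l, d)$ for a batch of $b$ sequences of length $l$. At each residue position there are $d=21$ numbers. They give the probability of each of the $21$ characters (the 20 amino acids plus the gap symbol `-`) occurring at that position. The function `mysoftmax` turns arbitrary real-valued scores into such probabilities. It exponentiates them, which makes every entry positive, and then divides by their sum along the last axis, so that the $21$ numbers at each position sum to one. `get_samples` then turns each distribution into a one-hot vector by picking its most probable character."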
]
},
{
"cell_type": "code",
"metadata": {
"id": "30a1aY7yZSuD"
},
"source": [
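"def mysoftmax(X):\n",
"    # subtract the per-position maximum before exponentiating to avoid overflow;\n",
"    # this does not change the result\n",
"    eX = np.exp(X - np.max(X, axis=-1, keepdims=True))\n",
"    # normalize over the last (character) axis so each position sums to one\n",
"    return eX / np.sum(eX, axis=-1, keepdims=True)\n",
"\n",
"a = np.random.randn(2, 6, 21)  # raw scores: 2 sequences, length 6, 21 characters\n",
"a_prob = mysoftmax(a)\n",
"\n",
"def get_samples(Xt_p, sampling='max'):\n",
"    \"\"\"Convert per-position probabilities of shape (b, l, d) into one-hot sequences.\"\"\"\n",
"    Xt_sampled = np.zeros_like(Xt_p)\n",
"    b, l, d = Xt_p.shape\n",
"    for i in range(b):\n",
"        for j in range(l):\n",
"            p = Xt_p[i, j]\n",
"            if sampling == 'max':            # most probable character\n",
"                k = np.argmax(p)\n",
"            elif sampling == 'multinomial':  # draw a character from p\n",
"                k = np.random.choice(d, p=p)\n",
"            else:\n",
"                raise ValueError(\"sampling must be 'max' or 'multinomial'\")\n",
"            Xt_sampled[i, j, k] = 1.\n",
"    return Xt_sampled"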
],
"execution_count": null,
"outputs": []
},
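Here is a quick sanity check of these properties, together with a loop-free version of the argmax step in `get_samples`. This sketch redefines `mysoftmax` with the common trick of subtracting the per-position maximum before exponentiating. The trick leaves the result unchanged but avoids overflow for large scores. `get_samples_vectorized` is a hypothetical name used only here.

```python
import numpy as np

def mysoftmax(X):
    # subtracting the per-position maximum does not change the softmax, but keeps exp() from overflowing
    eX = np.exp(X - np.max(X, axis=-1, keepdims=True))
    return eX / np.sum(eX, axis=-1, keepdims=True)

def get_samples_vectorized(Xt_p):
    """One-hot encode the most probable character at every position, without Python loops."""
    d = Xt_p.shape[-1]
    # np.argmax gives a (b, l) array of indices; indexing the identity matrix turns them into one-hot rows
    return np.eye(d)[np.argmax(Xt_p, axis=-1)]   # shape (b, l, d)

rng = np.random.default_rng(2)
a_prob = mysoftmax(rng.standard_normal((2, 6, 21)))

print(np.allclose(a_prob.sum(axis=-1), 1.0))   # True: each position is a probability distribution
print((a_prob > 0).all())                      # True: every probability is positive
onehot = get_samples_vectorized(a_prob)
print(onehot.shape)                            # (2, 6, 21)
```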
{
"cell_type": "markdown",
"metadata": {
"id": "cTbaUm19m1hV"
},
"source": [
"### Working with Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DzSRiwbVVZbQ"
},
"source": [
"To represent biological sequences on a computer, we need a way to convert a sequence of letters into an array of numbers. The simplest way is one-hot encoding. Each letter is assigned an index and is represented by a vector of zeros with a single 1 at that index. In the code below, A has index 0, R index 1 and N index 2 (Python counts from zero). So the sequence 'ARN' is represented by a $3 \\times 21$ matrix:\n",
"\n",
"$$\n",
"\n",
"\\begin{pmatrix}\n",
" 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\\\\n",
" 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\\\\n",
" 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \n",
"\\end{pmatrix}\n",
"\n",
"$$"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ftZC5puule2l",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4403a4d6-4964-49c1-8f26-9b8c1c6f9448"
},
"source": [
"import numpy as np\n",
"\n",
"# 20 amino acids plus '-' for a gap; the position in this list is the one-hot index\n",
"AA = ['a', 'r', 'n', 'd', 'c', 'q', 'e', 'g', 'h', 'i',\n",
"      'l', 'k', 'm', 'f', 'p', 's', 't', 'w', 'y', 'v', '-']\n",
"AA_IDX = {AA[i]:i for i in range(len(AA))}\n",
"IDX_AA = {i:AA[i].upper() for i in range(len(AA))}\n",
"\n",
"def one_hot_encode_aa(aa_str, pad=None):\n",
"    aa_str = aa_str.lower()\n",
"    if pad is not None:\n",
"        # right-pad with gap characters up to length `pad`\n",
"        aa_str = aa_str.ljust(pad, '-')\n",
"    M = len(aa_str)\n",
"    aa_arr = np.zeros((M, 21), dtype=int)\n",
"    for i in range(M):\n",
"        aa_arr[i, AA_IDX[aa_str[i]]] = 1\n",
"    return aa_arr\n",
"\n",
"def get_X(seqs):\n",
"    M = len(seqs[0])\n",
"    N = len(seqs)\n",
"    X = []\n",
"    for i in range(N):\n",
"        try:\n",
"            X.append(one_hot_encode_aa(seqs[i]))\n",
"        except KeyError:\n",
"            # the sequence contains a character not in AA (e.g. 'X'); it is silently skipped\n",
"            pass\n",
"    return np.array(X)\n",
"\n",
"seq = \"MNFPRA\"\n",
"one_hot_encode_aa(seq)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],\n",
" [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
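`one_hot_encode_aa` fills the matrix one position at a time. An equivalent approach first looks up each letter's index. It then selects the matching rows of a $21 \times 21$ identity matrix in a single indexing step. The sketch below implements this with a hypothetical helper, `one_hot_eye`. The `AA` ordering is copied from the cell above.

```python
import numpy as np

AA = ['a', 'r', 'n', 'd', 'c', 'q', 'e', 'g', 'h', 'i',
      'l', 'k', 'm', 'f', 'p', 's', 't', 'w', 'y', 'v', '-']
AA_IDX = {aa: i for i, aa in enumerate(AA)}

def one_hot_eye(seq):
    """Map each letter to its index in AA, then pick the matching rows of the identity matrix."""
    idx = np.array([AA_IDX[c] for c in seq.lower()])  # raises KeyError for letters outside AA
    return np.eye(len(AA), dtype=int)[idx]            # shape (len(seq), 21)

X = one_hot_eye("ARN")
print(X.shape)               # (3, 21)
print(np.argmax(X, axis=1))  # [0 1 2]: A, R and N sit at indices 0, 1 and 2
```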
{
"cell_type": "markdown",
"metadata": {
"id": "SRR4Q1ycnYC6"
},
"source": [
"Exercises:\n",
"\n",
"\n",
"1. What do $N$ and $M$ represent in the ```get_X``` function?\n",
"2. Write code to convert the one-hot encoded matrix $X$ back into sequences. Assert that every decoded sequence is identical to the original.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qtbiiaCPXLok"
},
"source": [
"Data often comes as a comma-separated values (CSV) file, where each line is a record whose fields are separated by commas. An example of loading protein sequence/function data is given below."
]
},
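`pd.read_csv` does the parsing for you. To see what it is working with, here is a minimal sketch using Python's standard-library `csv` module. The two rows are made up for illustration. Only the column names `sequence` and `relative_activity` come from the code below.

```python
import csv
import io

# two made-up rows in a layout with the columns `sequence` and `relative_activity`
text = """sequence,relative_activity
MNFPRA,0.5
MNFPRG,1.2
"""

rows = list(csv.DictReader(io.StringIO(text)))  # one dict per line, keyed by the header row
seqs = [r["sequence"] for r in rows]
activity = [float(r["relative_activity"]) for r in rows]  # csv yields strings; convert explicitly
print(seqs, activity)   # ['MNFPRA', 'MNFPRG'] [0.5, 1.2]
```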
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Qn9eBVUkqX83",
"outputId": "c776aae9-5837-46be-bf6d-fb11521de98d"
},
"source": [
"!git clone https://github.com/igemto-drylab/CSBERG-ML.git\n",
"%cd CSBERG-ML\n",
"from util import *"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Cloning into 'CSBERG-ML'...\n",
"remote: Enumerating objects: 104, done.\u001b[K\n",
"remote: Counting objects: 100% (104/104), done.\u001b[K\n",
"remote: Compressing objects: 100% (99/99), done.\u001b[K\n",
"remote: Total 104 (delta 46), reused 19 (delta 2), pack-reused 0\u001b[K\n",
"Receiving objects: 100% (104/104), 16.97 MiB | 15.49 MiB/s, done.\n",
"Resolving deltas: 100% (46/46), done.\n",
"/content/CSBERG-ML\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "HyByq5RfmkWc"
},
"source": [
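"import numpy as np\n",
"import pandas as pd\n",
"\n",
"df = pd.read_csv(\"petase_activity.csv\")\n",
"# one-hot encode the first 212 sequences; get_X silently drops any sequence\n",
"# containing a character outside AA, which would misalign X and y\n",
"X = get_X(list(df['sequence'])[:212])\n",
"X = X.reshape(-1, 298, 21)  # (number of sequences, 298 positions, 21 characters)\n",
"y = np.array(list(df['relative_activity']))[:212]\n",
"assert len(X) == len(y), \"get_X dropped some sequences; X and y no longer line up\"\n",
"y = np.log(y + 0.001)\n",
"y = 2*((y - np.min(y)) / (np.max(y) - np.min(y))) - 1"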
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "3zZDnThMaTzM"
},
"source": [
"Exercise:\n",
"\n",
"1. What do the last two lines of the code do?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ybI_EyoLbFA7"
},
"source": [
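"A common format for biological sequences is FASTA. Like many file formats, it is just a plain-text file with a few rules. Each record starts with a header line beginning with `>`, followed by one or more lines containing the sequence. Below we define two helper functions. `read_fasta` reads a FASTA file into a one-hot PyTorch tensor. `save_fasta` writes a batch of per-position probability distributions over protein sequence space to a FASTA file, either by taking the most probable character at each position or by sampling from the distribution."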
]
},
{
"cell_type": "code",
"metadata": {
"id": "Bq_ypuDYYgLw"
},
"source": [
"import torch\n",
"\n",
"def read_fasta(fname):\n",
"    \"\"\"Read a FASTA file into a one-hot tensor of shape (N, L, 21).\"\"\"\n",
"    seqs = []\n",
"    s = \"\"\n",
"    with open(fname) as f:\n",
"        for line in f:\n",
"            line = line.strip()\n",
"            if line.startswith(\">\"):  # a header line starts a new record\n",
"                if s:\n",
"                    seqs.append(s)\n",
"                s = \"\"\n",
"            elif line:  # sequence lines may be wrapped; join them\n",
"                s += line\n",
"    if s:\n",
"        seqs.append(s)\n",
"\n",
"    X = torch.tensor(get_X(seqs))\n",
"\n",
"    return X\n",
"\n",
"\n",
"def save_fasta(X_p, fname, sampling='max'):\n",
"    \"\"\"Write per-position probabilities X_p of shape (b, l, 21) to a FASTA file.\"\"\"\n",
"    if torch.is_tensor(X_p):\n",
"        X_p = X_p.detach().cpu().numpy()\n",
"    b, l, d = X_p.shape\n",
"\n",
"    seqs = \"\"\n",
"    for i in range(b):\n",
"        seqs += \">{}\\n\".format(i)\n",
"        for j in range(l):\n",
"            p = X_p[i, j]\n",
"            if sampling == 'max':  # only take the character with max probability\n",
"                k = np.argmax(p)\n",
"            elif sampling == 'multinomial':  # sample from the distribution p\n",
"                k = np.random.choice(d, p=p / p.sum())\n",
"            else:\n",
"                raise ValueError(\"sampling must be 'max' or 'multinomial'\")\n",
"            aa = IDX_AA[k]\n",
"            if aa != '-':  # gaps are dropped from the written sequence\n",
"                seqs += aa\n",
"        seqs += \"\\n\"\n",
"    with open(fname, \"w\") as f:\n",
"        f.write(seqs)\n"
],
"execution_count": null,
"outputs": []
},
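`read_fasta` discards the header lines. If you also need the record names, the sketch below yields `(header, sequence)` pairs from any iterable of lines. `parse_fasta` is a hypothetical helper defined only here. Biopython's `SeqIO.parse` is a widely used, more complete alternative.

```python
import io

fasta_text = """>seq1 first example
MNFPRA
KLV
>seq2
SILKEN
"""

def parse_fasta(handle):
    """Yield (header, sequence) pairs; a sequence may be wrapped over several lines."""
    header, chunks = None, []
    for line in handle:
        line = line.strip()
        if not line:
            continue                       # ignore blank lines
        if line.startswith(">"):
            if header is not None:
                yield header, "".join(chunks)
            header, chunks = line[1:], []  # the text after '>' names the record
        else:
            chunks.append(line)
    if header is not None:
        yield header, "".join(chunks)

records = list(parse_fasta(io.StringIO(fasta_text)))
print(records)  # [('seq1 first example', 'MNFPRAKLV'), ('seq2', 'SILKEN')]
```

Collecting the pieces in a list and joining them once avoids building the sequence by repeated string concatenation.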
{
"cell_type": "markdown",
"metadata": {
"id": "bcNVpstMcbDh"
},
"source": [
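"Amino acid Scrabble seems like a good idea. As a toy dataset, we use ten six-letter English words spelled entirely with one-letter amino acid codes."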
]
},
{
"cell_type": "code",
"metadata": {
"id": "dUWA7NqAZFPc",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "24a7a7f2-14f3-4b74-b6c5-228a1aa683de"
},
"source": [
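"# ten six-letter words spelled only with one-letter amino acid codes\n",
"fragments = np.array([\"SILKEN\", \"ANGELS\", \"AGEIST\", \"SPACES\",\n",
"                      \"WINTER\", \"SHIVER\", \"EARTHY\", \"PEARLS\", \"DAMSEL\", \"DANCES\"])\n",
"X = get_X(fragments)\n",
"print(X.shape)  # (number of sequences, sequence length, alphabet size)\n",
"\n",
"# random per-position distributions for 2 sequences of length 6\n",
"X_p = np.random.randn(2,6,21)\n",
"X_p = mysoftmax(X_p)\n",
"print(X_p.shape)\n",
"\n",
"save_fasta(X_p, \"random.fa\", sampling='multinomial')\n",
"\n",
"# softmax of the one-hot words: 'max' recovers the original words, 'multinomial' usually does not\n",
"X_p = mysoftmax(X)\n",
"save_fasta(X_p, \"multi.fa\", sampling='multinomial')\n",
"save_fasta(X_p, \"max.fa\", sampling='max')\n",
"\n",
"# round trip: this should be identical to max.fa\n",
"save_fasta(read_fasta(\"max.fa\"), \"match.fa\", sampling='max')"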
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"(10, 6, 21)\n",
"(2, 6, 21)\n"
],
"name": "stdout"
}
]
},
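Why do `multi.fa` and `max.fa` differ? Applying a softmax to a one-hot row of length 21 gives the hot entry $e/(e+20) \approx 0.12$ and every other entry $1/(e+20) \approx 0.044$. `'max'` therefore always recovers the original letter. `'multinomial'` keeps it only about 12% of the time, so a whole six-letter word survives with probability roughly $0.12^6 \approx 3 \times 10^{-6}$. A quick NumPy check of these numbers:

```python
import numpy as np

d = 21
onehot_row = np.zeros(d)
onehot_row[3] = 1.0                                  # any single "hot" position behaves the same way

p = np.exp(onehot_row) / np.exp(onehot_row).sum()    # softmax of one row
print(p[3], np.e / (np.e + d - 1))                   # hot entry vs closed form e / (e + 20)
print(p[0], 1 / (np.e + d - 1))                      # any other entry vs 1 / (e + 20)
print(p[3] ** 6)                                     # chance that sampling reproduces a 6-letter word exactly
```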
{
"cell_type": "markdown",
"metadata": {
"id": "7cSijmVpctWT"
},
"source": [
"PyTorch provides the convenient `Dataset` and `DataLoader` classes for iterating over batches of your data. To split your data into training and test sets, `sklearn`'s `train_test_split` function is a good start. However, a random split is not always appropriate for protein sequences. Closely related sequences can end up in both sets, which makes test performance look better than it really is. We will cover these subtleties when we discuss generative models. The line `plt.rcParams['figure.dpi'] = 350` makes your plots high resolution."
]
},
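As a rough picture of what `train_test_split` plus a shuffling `DataLoader` do, here is a plain-NumPy sketch. `split_indices` and `iterate_minibatches` are hypothetical helpers, not the scikit-learn or PyTorch implementations.

```python
import numpy as np

def split_indices(n, test_size=0.2, seed=0):
    """Randomly split the indices 0..n-1 into train and test index arrays."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_test = int(np.ceil(test_size * n))
    return perm[n_test:], perm[:n_test]

def iterate_minibatches(X, idx, batch_size, seed=0):
    """Yield batches of rows of X, visiting the given indices in shuffled order."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(idx)          # shuffled copy of the index array
    for start in range(0, len(idx), batch_size):
        yield X[idx[start:start + batch_size]]

X = np.zeros((10, 6 * 21))              # stand-in for 10 flattened one-hot sequences
train_idx, test_idx = split_indices(len(X))
print(len(train_idx), len(test_idx))    # 8 2
print([batch.shape for batch in iterate_minibatches(X, train_idx, batch_size=3)])
# [(3, 126), (3, 126), (2, 126)]
```

In PyTorch, the same idea is typically expressed by wrapping the index lists in `torch.utils.data.Subset` and passing the result to `DataLoader(..., batch_size=3, shuffle=True)`.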
{
"cell_type": "code",
"metadata": {
"id": "2tAzpj9Lc6ej"
},
"source": [
"import json\n",
"import random\n",
"import math\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
"from torch.utils.data.sampler import SubsetRandomSampler\n",
"from sklearn.model_selection import train_test_split, ShuffleSplit\n",
"from scipy import stats\n",
"plt.rcParams['figure.dpi'] = 350"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "DuXFjQQaadaF",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "01338e93-dc94-4c74-f419-3fcafcb91098"
},
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"class SequenceData(Dataset):\n",
"    \"\"\"Minimal Dataset wrapping a tensor (or NumPy array) of encoded sequences.\"\"\"\n",
"\n",
"    def __init__(self, X):\n",
"        if not torch.is_tensor(X):\n",
"            self.X = torch.from_numpy(X)\n",
"        else:\n",
"            self.X = X\n",
"\n",
"    def __len__(self):\n",
"        return len(self.X)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.X[idx]\n",
"\n",
"X = read_fasta(\"max.fa\")\n",
"X = X.type(torch.FloatTensor).to(device)\n",
"b, seqlen, d = X.size()\n",
"X = X.view(b, seqlen*d)  # flatten each one-hot sequence into a vector of length seqlen*d\n",
"\n",
"dataset = SequenceData(X)\n",
"# split the sequence indices 80/20 into training and test sets\n",
"trainset, testset = train_test_split(list(range(X.size(0))), test_size=.2)\n",
"print(\"Sequence length: \", seqlen)\n",
"print(\"Training on \")\n",
"print(fragments[trainset])\n",
"print(\"Testing on \")\n",
"print(fragments[testset])"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Sequence length: 6\n",
"Training on \n",
"['DAMSEL' 'EARTHY' 'PEARL' 'ANGELS' 'DANCES' 'SPACES' 'SILKEN' 'SHIVER']\n",
"Testing on \n",
"['WINTER' 'AGEIST']\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "3lKRDduMLUEs",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a961bd97-6757-4bd9-d4fe-90d522f2f769"
},
"source": [
"X.shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([10, 126])"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
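`X.view(b, seqlen*d)` places the 6 one-hot rows of each sequence end to end, which is where the width $6 \times 21 = 126$ comes from. The NumPy sketch below performs the same reshape. It assumes NumPy's default row-major order, which matches `view` on a contiguous tensor. The sketch shows where each position lands and that the operation loses no information.

```python
import numpy as np

b, seqlen, d = 10, 6, 21
rng = np.random.default_rng(3)
X = np.eye(d)[rng.integers(0, d, size=(b, seqlen))]   # random one-hot batch, shape (10, 6, 21)

flat = X.reshape(b, seqlen * d)                        # analogous to X.view(b, seqlen*d)
print(flat.shape)                                      # (10, 126)
# row-major order: position j of sequence i occupies columns j*d .. j*d + d - 1
print(np.array_equal(flat[0, 2 * d:3 * d], X[0, 2]))  # True
restored = flat.reshape(b, seqlen, d)                  # reshaping back recovers the original
print(np.array_equal(restored, X))                     # True
```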
{
"cell_type": "markdown",
"metadata": {
"id": "EK4kww6vhjmT"
},
"source": [
"We've only covered sequence data. In section 7 we will work with protein structure data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-jrx7bWJqRAc"
},
"source": [
"Exercises:\n",
"\n",
"\n",
"1. You're designing a signal peptide with 6 amino acids. Experiments show that for it to function, there must be a lysine at the second position. Draw an array of shape $1000 \\times 6 \\times 21$ from a unit Gaussian and normalize it into valid per-position probabilities using ```mysoftmax()```. Sample one sequence from each of these 1000 distributions over sequences. **Without using for loops**, filter the result by extracting those sequences with a lysine at the second position.\n",
"2. Visualize the distribution of the sampled sequences as a sequence logo, before and after the filtering.\n",
"\n"
]
}
]
}