Last active
August 29, 2015 14:26
-
-
Save nmiano1111/c180efd1a37d7faebbff to your computer and use it in GitHub Desktop.
A translation of the script from the blog post http://iamtrask.github.io/2015/07/12/basic-python-network/ from Python and NumPy to Scala and ND4J (http://nd4j.org/).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import org.nd4j.api.linalg.DSL._ | |
| import org.nd4j.linalg.api.ndarray.INDArray | |
| import org.nd4j.linalg.factory.Nd4j | |
| import org.nd4j.linalg.lossfunctions.LossFunctions | |
| import org.nd4j.linalg.ops.transforms.Transforms | |
/**
 * A very simple two-layer neural network implemented with Scala and ND4J,
 * translated from the NumPy example at
 * http://iamtrask.github.io/2015/07/12/basic-python-network/.
 *
 * The network is trained on a tiny XOR data set. In each input row the first
 * two columns are the XOR operands and the third column is a constant 1 (a
 * bias input). The target for each row is the XOR of its first two columns.
 */
object ToyNet {

  /** Number of full-batch training iterations. */
  private val Iterations = 60000

  /** How often (in iterations) the mean absolute training error is printed. */
  private val LogEvery = 10000

  /**
   * Creates a randomly initialised weight matrix.
   *
   * Entries come from `Nd4j.rand` and are rescaled with `2 * r - 1`. Assuming
   * `Nd4j.rand` samples uniformly from [0, 1), the weights lie in [-1, 1).
   *
   * @param rows number of rows (inputs feeding the layer)
   * @param cols number of columns (outputs of the layer)
   * @return a new `rows x cols` weight matrix
   */
  def layer(rows: Int, cols: Int): INDArray = (Nd4j.rand(rows, cols) * 2) - 1

  /** Applies one fully connected layer followed by a sigmoid activation. */
  private def forward(input: INDArray, weights: INDArray): INDArray =
    Transforms.sigmoid(input.dot(weights))

  /** Runs `input` through both layers of the network and returns the output. */
  private def predict(input: INDArray, syn0: INDArray, syn1: INDArray): INDArray =
    forward(forward(input, syn0), syn1)

  /**
   * Derivative of the sigmoid, expressed in terms of its output.
   * If `s = sigmoid(x)`, then `ds/dx = s * (1 - s)`.
   */
  private def sigmoidDerivative(s: INDArray): INDArray = s * ((s * (-1)) + 1)

  /** Trains the network on the XOR data set and prints two sample predictions. */
  def main(args: Array[String]): Unit = {
    // Each row: two XOR operands followed by a constant bias input of 1.
    val X = Nd4j.create(Array[Array[Float]](
      Array[Float](0, 0, 1),
      Array[Float](0, 1, 1),
      Array[Float](1, 0, 1),
      Array[Float](1, 1, 1)))
    // Target column vector (4 x 1): XOR of the first two input columns.
    val y = Nd4j.create(Array[Float](0, 1, 1, 0), Array[Int](4, 1))

    // These are `val`s, so the `+=` updates below must mutate the arrays in place.
    val syn0 = layer(3, 4)
    val syn1 = layer(4, 1)

    for (j <- 0 until Iterations) {
      // Feed forward through both layers.
      val l0 = X
      val l1 = forward(l0, syn0)
      val l2 = forward(l1, syn1)

      // How much did we miss the target value?
      val l2Error = y - l2

      // Report the mean of the *absolute* error. Averaging the signed error
      // first would let positive and negative misses cancel each other out.
      if (j % LogEvery == 0) println("Error: " + Nd4j.mean(Transforms.abs(l2Error)))

      // Scale the error by the sigmoid's slope. Confident outputs (near 0 or 1)
      // have a small slope, so they are changed only slightly.
      val l2Delta = l2Error * sigmoidDerivative(l2)

      // Back-propagate: how much did each l1 value contribute to the l2 error
      // (according to the weights)?
      val l1Error = l2Delta.dot(syn1.T)
      val l1Delta = l1Error * sigmoidDerivative(l1)

      syn1 += l1.T.dot(l2Delta)
      syn0 += l0.T.dot(l1Delta)
    }

    // Try the trained network on individual inputs.
    println(predict(Nd4j.create(Array[Float](0, 1, 1)), syn0, syn1))
    println(predict(Nd4j.create(Array[Float](1, 1, 1)), syn0, syn1))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment