Skip to content

Instantly share code, notes, and snippets.

@nmiano1111
Last active August 29, 2015 14:26
Show Gist options
  • Select an option

  • Save nmiano1111/c180efd1a37d7faebbff to your computer and use it in GitHub Desktop.

Select an option

Save nmiano1111/c180efd1a37d7faebbff to your computer and use it in GitHub Desktop.
Translation of the script from this blog post (http://iamtrask.github.io/2015/07/12/basic-python-network/) from Python and NumPy to Scala and ND4J (http://nd4j.org/).
import org.nd4j.api.linalg.DSL._
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.lossfunctions.LossFunctions
import org.nd4j.linalg.ops.transforms.Transforms
/**
 * Created by nmiano on 7/27/15.
 *
 * Trains a small fully-connected network (3 inputs -> 4 hidden units -> 1 output,
 * sigmoid activations) with plain batch gradient updates, then prints the
 * network's output for two sample inputs.
 *
 * The training targets equal the XOR of the first two input columns; the third
 * input column is constantly 1.
 */
object App {

  /** Number of full-batch training passes. */
  private val Iterations = 60000

  /** The mean absolute error is printed once every this many passes. */
  private val LogInterval = 10000

  /**
   * Builds a randomly initialised weight matrix.
   *
   * Every entry is `rand * 2 - 1`, i.e. the values produced by `Nd4j.rand`
   * rescaled and shifted (presumably from [0, 1) to [-1, 1) -- confirm against
   * the ND4J version in use).
   *
   * @param rows number of rows (size of the layer feeding into these weights)
   * @param cols number of columns (size of the layer these weights feed)
   * @return a `rows` x `cols` weight matrix
   */
  def layer(rows: Int, cols: Int): INDArray = (Nd4j.rand(rows, cols) * 2) - 1

  /**
   * Slope of the sigmoid expressed through its output: `a * (1 - a)`.
   *
   * @param activation values that already went through the sigmoid
   * @return element-wise `activation * (1 - activation)`
   */
  private def sigmoidSlope(activation: INDArray): INDArray =
    activation * ((activation * (-1)) + 1)

  /**
   * Runs the forward pass through both layers.
   *
   * @param input one or more input rows with three columns each
   * @param syn0  input -> hidden weights
   * @param syn1  hidden -> output weights
   * @return the output-layer activations
   */
  private def predict(input: INDArray, syn0: INDArray, syn1: INDArray): INDArray = {
    val hidden = Transforms.sigmoid(input.dot(syn0))
    Transforms.sigmoid(hidden.dot(syn1))
  }

  /**
   * Trains the network on the four-row data set and prints the predictions
   * for the inputs (0, 1, 1) and (1, 1, 1).
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val X = Nd4j.create(Array[Array[Float]](
      Array[Float](0, 0, 1),
      Array[Float](0, 1, 1),
      Array[Float](1, 0, 1),
      Array[Float](1, 1, 1)))
    val y = Nd4j.create(Array[Float](0, 1, 1, 0), Array[Int](4, 1))

    val syn0 = layer(3, 4)
    val syn1 = layer(4, 1)

    // train
    for (j <- 0 until Iterations) {
      // feed forward
      val l0 = X
      val l1 = Transforms.sigmoid(l0.dot(syn0))
      val l2 = Transforms.sigmoid(l1.dot(syn1))

      // how much did we miss the target value?
      val l2Error = y - l2

      // Log the mean of the absolute errors. Taking the absolute value first
      // keeps positive and negative per-sample errors from cancelling out.
      // The copy guards l2Error, which is still needed below, in case the
      // transform works in place (NOTE(review): confirm for the ND4J version used).
      if (j % LogInterval == 0) {
        println(s"Error: ${Nd4j.mean(Transforms.abs(l2Error.dup()))}")
      }

      // in what direction is the target value?
      // were we really sure? if so, don't change too much.
      val l2Delta = l2Error * sigmoidSlope(l2)

      // how much did each l1 value contribute to the l2 error (according to the weights)?
      // Must be computed before syn1 is updated below.
      val l1Error = l2Delta.dot(syn1.T)

      // in what direction is the target l1?
      // were we really sure? if so, don't change too much.
      val l1Delta = l1Error * sigmoidSlope(l1)

      syn1 += l1.T.dot(l2Delta)
      syn0 += l0.T.dot(l1Delta)
    }

    // playing around with it
    println(predict(Nd4j.create(Array[Float](0, 1, 1)), syn0, syn1))
    println(predict(Nd4j.create(Array[Float](1, 1, 1)), syn0, syn1))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment