Last active
March 19, 2018 16:46
-
-
Save reiinakano/befceb4df788b86b1b6720f457df072d to your computer and use it in GitHub Desktop.
dljs xor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Network architecture: 2 inputs -> 20-unit hidden layer -> 5-unit hidden layer -> 1 output.
const hiddenNumNeurons = 20;
const hidden2NumNeurons = 5;
// Trainable parameters: weights drawn from a random normal, biases start at zero.
const weights = dl.variable(dl.randomNormal([2, hiddenNumNeurons]));
const biases = dl.variable(dl.zeros([hiddenNumNeurons]));
const weights2 = dl.variable(dl.randomNormal([hiddenNumNeurons, hidden2NumNeurons]));
const biases2 = dl.variable(dl.zeros([hidden2NumNeurons]));
const outWeights = dl.variable(dl.randomNormal([hidden2NumNeurons, 1]));
const outBias = dl.variable(dl.zeros([1]));
// Training hyperparameters: Adam optimizer, one fresh sample per iteration.
const learningRate = 0.01;
const optimizer = dl.train.adam(learningRate);
const num_iterations = 200;
/*
 * Given an input [x0, x1], run a forward pass through the two hidden
 * ReLU layers and the sigmoid output, returning a scalar in (0, 1).
 * dl.tidy disposes all intermediate tensors once the pass completes.
 */
function predict(input) {
  return dl.tidy(() => {
    const x = dl.tensor2d(input, [1, 2]);
    const layer1 = x.matMul(weights).add(biases).relu();
    const layer2 = layer1.matMul(weights2).add(biases2).relu();
    return layer2.matMul(outWeights).add(outBias).sigmoid().asScalar();
  });
}
/*
 * Binary cross-entropy loss for a single prediction/label pair.
 * prediction: scalar tensor in (0, 1) from predict(); actual: number 0 or 1.
 * Returns a scalar tensor suitable for optimizer.minimize().
 */
function loss(prediction, actual) {
  // Having a good error metric is key for training a machine learning model
  if (actual === 1) {
    // -log(p): penalizes predictions near 0 when the true label is 1.
    return prediction.log().neg();
  }
  // -log(1 - p): penalizes predictions near 1 when the true label is 0.
  return dl.scalar(1).sub(prediction).log().neg();
}
/*
 * Trains the model asynchronously for numIterations steps, one fresh
 * random sample per step, logging the loss every 10 iterations.
 * Calls `done` once training finishes.
 */
async function train(numIterations, done) {
  const [xs, ys] = getNRandomSamples(numIterations);
  for (let iter = 0; iter < numIterations; iter++) {
    // `true` asks minimize() to return the loss tensor so we can report it.
    const cost = optimizer.minimize(() => {
      return loss(predict(xs[iter]), ys[iter]);
    }, true);
    if (iter % 10 === 0) {
      // data() resolves asynchronously; progress logs may arrive slightly
      // out of order relative to later iterations, which is acceptable here.
      cost.data().then((data) => console.log(`Iteration: ${iter} Loss: ${data}`));
    }
    // Yield to the browser so the UI stays responsive during training.
    await dl.nextFrame();
  }
  done();
}
/*
 * Evaluates the model on the given samples (xs) and labels (ys),
 * logging the number of correct predictions and the accuracy.
 * Wrapped in dl.tidy so the tensors predict() creates are disposed.
 */
function test(xs, ys) {
  dl.tidy(() => {
    // Threshold the sigmoid output at 0.5 to get a hard 0/1 prediction.
    const predictedYs = xs.map((x) => Math.round(predict(x).dataSync()));
    let predicted = 0;
    for (let i = 0; i < xs.length; i++) {
      if (ys[i] === predictedYs[i]) {
        predicted++;
      }
    }
    console.log(`Num correctly predicted: ${predicted} out of ${xs.length}`);
    console.log(`Accuracy: ${predicted/xs.length}`);
  })
}
/*
 * Draws one random point in [-1, 1)^2 and labels it with the XOR
 * quadrant rule: same-sign coordinates get label 0, opposite signs
 * (or a zero coordinate) get label 1. Returns [point, label].
 */
function getRandomSample() {
  const x = [Math.random() * 2 - 1, Math.random() * 2 - 1];
  const sameSign = (x[0] > 0 && x[1] > 0) || (x[0] < 0 && x[1] < 0);
  const y = sameSign ? 0 : 1;
  return [x, y];
}
/*
 * Generates n random labeled samples via getRandomSample().
 * Returns [xs, ys] where xs is an array of n points and ys the
 * corresponding array of n labels.
 */
function getNRandomSamples(n) {
  const xs = [];
  const ys = [];
  for (let i = 0; i < n; i++) {
    const [sample, label] = getRandomSample();
    xs.push(sample);
    ys.push(label);
  }
  return [xs, ys];
}
// Hold out 100 random samples for evaluating the model.
const [testX, testY] = getNRandomSamples(100);

// Baseline accuracy before any training (should be near chance).
console.log(`Before training: `);
test(testX, testY);

console.log('=============');
console.log(`Training ${num_iterations} epochs...`);

// Train, then evaluate again on the same held-out set.
train(num_iterations, () => {
  console.log('=============');
  console.log(`After training:`);
  test(testX, testY);
});
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment