Skip to content

Instantly share code, notes, and snippets.

@reiinakano
Last active March 19, 2018 16:46
Show Gist options
  • Select an option

  • Save reiinakano/befceb4df788b86b1b6720f457df072d to your computer and use it in GitHub Desktop.

Select an option

Save reiinakano/befceb4df788b86b1b6720f457df072d to your computer and use it in GitHub Desktop.
dljs xor
// Width of the first hidden layer.
const hiddenNumNeurons = 20;
// Width of the second hidden layer.
const hidden2NumNeurons = 5;
// First hidden layer parameters: a [2, hiddenNumNeurons] weight matrix drawn
// from a normal distribution and a zero-initialised bias vector.
const weights = dl.variable(dl.randomNormal([2, hiddenNumNeurons]));
const biases = dl.variable(dl.zeros([hiddenNumNeurons]));
// Second hidden layer parameters: [hiddenNumNeurons, hidden2NumNeurons] weights
// and a zero-initialised bias vector.
const weights2 = dl.variable(dl.randomNormal([hiddenNumNeurons, hidden2NumNeurons]));
const biases2 = dl.variable(dl.zeros([hidden2NumNeurons]));
// Output layer parameters: [hidden2NumNeurons, 1] weights and a single bias.
const outWeights = dl.variable(dl.randomNormal([hidden2NumNeurons, 1]));
const outBias = dl.variable(dl.zeros([1]));
// Step size handed to the Adam optimizer below.
const learningRate = 0.01;
const optimizer = dl.train.adam(learningRate);
// Number of training iterations; train() draws one random sample per iteration.
const num_iterations = 200;
/*
* Forward pass of the network for a single sample.
*
* `input` is a pair of numbers; it is wrapped in a [1, 2] tensor, pushed
* through two ReLU layers and a sigmoid output layer, and returned as a
* scalar tensor. Intermediate tensors are cleaned up by dl.tidy.
*/
function predict(input) {
  // One affine step: t x w + b.
  const affine = (t, w, b) => t.matMul(w).add(b);

  return dl.tidy(() => {
    const features = dl.tensor2d(input, [1, 2]);
    const firstActivation = affine(features, weights, biases).relu();
    const secondActivation = affine(firstActivation, weights2, biases2).relu();
    return affine(secondActivation, outWeights, outBias).sigmoid().asScalar();
  });
}
/*
* Negative log-likelihood of a single prediction.
*
* `prediction` is the tensor produced by the model and `actual` is the label.
* When the label equals 1 the result is -log(prediction); for any other label
* it is -log(1 - prediction).
*/
function loss(prediction, actual) {
  // Pick the probability the model assigned to the observed label, then
  // penalise it with -log. Loose equality is kept on purpose so labels such
  // as "1" or true are still treated as the positive class.
  const assignedProbability = actual == 1
    ? prediction
    : dl.scalar(1).sub(prediction);
  return assignedProbability.log().neg();
}
/*
* Trains the model asynchronously, one random sample per iteration.
*
* numIterations - number of optimizer steps to run; that many random samples
*                 are drawn up front with getNRandomSamples.
* done          - optional callback invoked once after the last iteration.
*
* Returns a promise that resolves after `done` has been called. Every 10th
* iteration the current loss is logged to the console.
*/
async function train(numIterations, done) {
  const [xs, ys] = getNRandomSamples(numIterations);
  for (let iter = 0; iter < numIterations; iter++) {
    // Only ask the optimizer to hand back the cost tensor on iterations that
    // are actually logged; otherwise it would be allocated and never freed.
    const shouldLog = iter % 10 === 0;
    const cost = optimizer.minimize(
      () => loss(predict(xs[iter]), ys[iter]),
      shouldLog
    );
    if (cost) {
      try {
        // Wait for the value before disposing so the read cannot race the free.
        const data = await cost.data();
        console.log(`Iteration: ${iter} Loss: ${data}`);
      } finally {
        // The returned cost is owned by the caller; release it explicitly.
        cost.dispose();
      }
    }
    // Yield to the browser so the page stays responsive while training.
    await dl.nextFrame();
  }
  if (typeof done === 'function') {
    done();
  }
}
/*
* Logs how many of the given samples the model classifies correctly.
*
* xs - array of input pairs, each passed to predict().
* ys - array of numeric labels (0 or 1), aligned with xs by index.
*
* Each prediction is rounded to the nearest integer and compared with its
* label. The count and the accuracy ratio are printed to the console; an
* empty sample set reports an accuracy of 0 rather than NaN.
*/
function test(xs, ys) {
  dl.tidy(() => {
    // dataSync() yields a typed array; read the single value explicitly
    // instead of relying on array-to-number coercion inside Math.round.
    const predictedYs = xs.map((x) => Math.round(predict(x).dataSync()[0]));
    const predicted = predictedYs.filter((label, i) => label === ys[i]).length;
    const accuracy = xs.length === 0 ? 0 : predicted / xs.length;
    console.log(`Num correctly predicted: ${predicted} out of ${xs.length}`);
    console.log(`Accuracy: ${accuracy}`);
  });
}
/*
* Draws one random point with both coordinates in [-1, 1) and labels it.
*
* Returns [point, label] where label is 0 when both coordinates are strictly
* positive or both strictly negative, and 1 otherwise (including when either
* coordinate is exactly zero).
*/
function getRandomSample() {
  const first = Math.random() * 2 - 1;
  const second = Math.random() * 2 - 1;
  const bothPositive = first > 0 && second > 0;
  const bothNegative = first < 0 && second < 0;
  const label = bothPositive || bothNegative ? 0 : 1;
  return [[first, second], label];
}
/*
* Collects n samples from getRandomSample().
*
* Returns [inputs, labels]: two parallel arrays where labels[i] is the label
* of inputs[i].
*/
function getNRandomSamples(n) {
  const inputs = [];
  const labels = [];
  // Keep drawing until n samples have been gathered.
  while (inputs.length < n) {
    const [point, label] = getRandomSample();
    inputs.push(point);
    labels.push(label);
  }
  return [inputs, labels];
}
// Fixed evaluation set, reused before and after training so the two accuracy
// figures are comparable.
const [testX, testY] = getNRandomSamples(100);
// Test before training
console.log(`Before training: `);
test(testX, testY);
console.log('=============');
console.log(`Training ${num_iterations} epochs...`);
// Train, then test right after
train(num_iterations, () => {
  console.log('=============');
  console.log(`After training:`);
  test(testX, testY);
}).catch((err) => {
  // train() is async; surface failures instead of leaving an unhandled rejection.
  console.error('Training failed:', err);
});
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment