Last active
March 11, 2016 21:30
-
-
Save Bcohn/2ee6efa8ea854f4565c7 to your computer and use it in GitHub Desktop.
softmax
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| fruit | width | height | |
|---|---|---|---|
| 1 | 8.4 | 7.3 | |
| 1 | 8 | 6.8 | |
| 1 | 7.4 | 7.2 | |
| 1 | 7.1 | 7.8 | |
| 1 | 7.4 | 7 | |
| 1 | 6.9 | 7.3 | |
| 1 | 7.1 | 7.6 | |
| 1 | 7 | 7.1 | |
| 1 | 7.3 | 7.7 | |
| 1 | 7.6 | 7.3 | |
| 1 | 7.7 | 7.1 | |
| 1 | 7.6 | 7.5 | |
| 1 | 7.5 | 7.6 | |
| 1 | 7.5 | 7.1 | |
| 1 | 7.4 | 7.2 | |
| 1 | 7.5 | 7.5 | |
| 1 | 7.4 | 7.4 | |
| 1 | 7.3 | 7.1 | |
| 1 | 7.6 | 7.9 | |
| 2 | 6.2 | 4.7 | |
| 2 | 6 | 4.6 | |
| 2 | 5.8 | 4.3 | |
| 2 | 5.9 | 4.3 | |
| 2 | 5.8 | 4 | |
| 2 | 9 | 9.4 | |
| 2 | 9.2 | 9.2 | |
| 2 | 9.6 | 9.2 | |
| 2 | 7.5 | 9.2 | |
| 2 | 6.7 | 7.1 | |
| 2 | 7 | 7.4 | |
| 2 | 7.1 | 7.5 | |
| 2 | 7.8 | 8 | |
| 2 | 7.2 | 7 | |
| 2 | 7.5 | 8.1 | |
| 2 | 7.6 | 7.8 | |
| 2 | 7.1 | 7.9 | |
| 2 | 7.1 | 7.6 | |
| 2 | 7.3 | 7.3 | |
| 2 | 7.2 | 7.8 | |
| 2 | 6.8 | 7.4 | |
| 2 | 7.1 | 7.5 | |
| 2 | 7.6 | 8.2 | |
| 2 | 7.2 | 7.2 | |
| 3 | 7.2 | 10.3 | |
| 3 | 7.3 | 10.5 | |
| 3 | 7.2 | 9.2 | |
| 3 | 7.3 | 10.2 | |
| 3 | 7.3 | 9.7 | |
| 3 | 7.3 | 10.1 | |
| 3 | 5.8 | 8.7 | |
| 3 | 6 | 8.2 | |
| 3 | 6 | 7.5 | |
| 3 | 5.9 | 8 | |
| 3 | 6 | 8.4 | |
| 3 | 6.1 | 8.5 | |
| 3 | 6.3 | 7.7 | |
| 3 | 5.9 | 8.1 | |
| 3 | 6.5 | 8.5 | |
| 3 | 6.1 | 8.1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<meta charset="utf-8">
<link href='https://fonts.googleapis.com/css?family=Open+Sans' rel='stylesheet' type='text/css'>
<style>
path {
  fill: none;
  stroke: #666;
  stroke-width: 1.5px;
}
/* NOTE(review): the script gives the axis groups the classes "x-axis" and
   "y-axis", so the .xAxis / .yAxis rules below do not match any element. */
.yAxis text {
  font: 10px sans-serif;
}
.xAxis text {
  font: 10px sans-serif;
}
/* Error-graph container, centred horizontally. */
.linegraph {
  display: block;
  margin: auto;
  height: 250px;
  width: 960px;
}
.axis path,
.axis line {
  fill: none;
  stroke: #808080;
  stroke-width: 1;
  shape-rendering: crispEdges;
}
.xAxis path, .xAxis line {
  fill: none;
  stroke: #000;
  shape-rendering: crispEdges;
}
/* Hyper-parameter input boxes. */
.but {
  width: 100px;
}
.dot {
  opacity: .7;
  fill: #1A2F4B;
}
text {
  font-family: 'Open Sans', sans-serif;
}
</style>
<body>
<div class="ui-widget">
  <input class="but" id="Iterations" type="search" placeholder="Iterations (10)" required>
  <input class="but" id="Eta" type="search" placeholder="Learning Rate (.01)" required>
  <input class="but" id="Lambda" type="search" placeholder="Penalty (.001)" required>
  <input class="icon-search" type="button" value="Fit">
</div>
<br>
<div id="graph1" class="linegraph"></div>
<br>
<div id="chart"></div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjs/2.7.0/math.min.js"></script>
<!-- Load jQuery exactly once, before every plugin that attaches to it. A second
     copy loaded later would replace window.jQuery / $ and silently discard the
     jQuery UI and Bootstrap plugins registered on the first copy. -->
<script src="//code.jquery.com/jquery-1.11.3.min.js"></script>
<script src="//code.jquery.com/ui/1.11.4/jquery-ui.js"></script>
<!-- Latest compiled and minified JavaScript -->
<script src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" crossorigin="anonymous"></script>
<script src="//d3js.org/d3.v3.min.js"></script>
<script src="//d3js.org/topojson.v1.min.js"></script>
<script src="https://gitcdn.github.io/bootstrap-toggle/2.2.0/js/bootstrap-toggle.min.js"></script>
<script src="http://labratrevenge.com/d3-tip/javascripts/d3.tip.v0.6.3.js"></script>
<script type="text/javascript" src="js/jquery.auto-complete.min.js"></script>
<link rel="stylesheet" type="text/css" href="js/jquery.auto-complete.css" />
| <script> | |
| (function(){ | |
// "Fit" button: read the hyper-parameters from the form (falling back to the
// defaults shown in the placeholders), reset the error history and refit.
// Iterations, lambda, eta and Errorlist are page-level globals that FitModel
// and GradientDescent read, so they are assigned on `window` explicitly
// instead of being created as implicit globals.
$(".icon-search").click(function () {
  // Parse a form field as a number; empty or non-numeric input yields `fallback`.
  function readNumber(selector, fallback) {
    const raw = String($(selector).val()).trim();
    const value = Number(raw);
    return raw !== "" && Number.isFinite(value) ? value : fallback;
  }

  // Gradient descent needs a whole, positive number of steps.
  window.Iterations = Math.max(1, Math.round(readNumber("#Iterations", 10)));
  window.lambda = readNumber("#Lambda", 0.001);
  window.eta = readNumber("#Eta", 0.01);
  window.Errorlist = [];
  FitModel();
});
// Centre the control bar and the error-graph container.
[".ui-widget", ".linegraph"].forEach((selector) => {
  d3.select(selector).attr("align", "center");
});
// Plot geometry shared by the scatter plot (#chart) and the error graph (#graph1).
const margin = { top: 20, right: 20, bottom: 30, left: 40 };
const width = 960 - margin.left - margin.right;
const height = 500 - margin.top - margin.bottom;

// Data-space -> pixel scales; their domains are set in create() once the CSV loads.
const x = d3.scale.linear().range([0, width]);
const y = d3.scale.linear().range([height, 0]);

// Class colours (red, green, blue). The ordinal domain is filled implicitly
// as create() colours the dots by fruit name.
const color = d3.scale.ordinal()
  .range([d3.rgb("red"), d3.rgb("green"), d3.rgb("blue")]);

const xAxis = d3.svg.axis().scale(x).orient("bottom");
const yAxis = d3.svg.axis().scale(y).orient("left");

// Error-history graph. Select the container by id instead of relying on the
// browser exposing element ids as window globals. The svg is widened by the
// margins because the group inside it is shifted right by margin.left; with
// only `width` pixels the right end of the error curve was clipped.
const graph = d3.select("#graph1")
  .append("svg:svg")
  .attr("width", width + margin.left + margin.right)
  .attr("height", 150)
  .append("g")
  .attr("transform", `translate(${margin.left},50)`);

// Scatter plot of the training data; Buckets() later adds decision tiles here.
const svg = d3.select("#chart")
  .attr("align", "center")
  .append("svg")
  .attr("class", "axis")
  .attr("width", width + margin.left + margin.right)
  .attr("height", height + margin.top + margin.bottom)
  .append("g")
  .attr("transform", `translate(${margin.left},${margin.top})`);
// Load the training data on page load: convert the numeric columns, replace
// the numeric fruit code with its display name, and draw the scatter plot.
d3.csv("fruit.csv", (error, rows) => {
  if (error) throw error;
  for (const row of rows) {
    row.fruit = fruitmap(Number(row.fruit));
    row.width = Number(row.width);
    row.height = Number(row.height);
  }
  create(rows);
});
/**
 * Reload fruit.csv, fit the softmax-regression weights with gradient descent
 * (using the page-level `Iterations` setting) and repaint the decision tiles.
 */
function FitModel() {
  d3.csv("fruit.csv", (error, dataInput) => {
    if (error) throw error;

    // One pass over the rows collects the labels and both feature columns.
    const oneHotLabels = [];
    const widths = [];
    const heights = [];
    for (const d of dataInput) {
      oneHotLabels.push(OneHotEncoder(Number(d.fruit)));
      widths.push(Number(d.width));
      heights.push(Number(d.height));
    }

    // One row per sample; columns are [width, height].
    const Xmatrix = math.transpose(math.matrix([widths, heights]));
    const Ymatrix = math.matrix(oneHotLabels);
    // 2 features x 3 classes, starting from all ones.
    const Weights = math.matrix([[1, 1, 1], [1, 1, 1]]);

    const fitted = GradientDescent(Ymatrix, Xmatrix, Weights, Iterations);
    Buckets(fitted);
  });
}
/**
 * Batch gradient descent for multinomial logistic (softmax) regression.
 *
 * Reads the page-level `eta` (learning rate) and `lambda` (L2 penalty). It
 * appends the cross-entropy measured at each step to the page-level
 * `Errorlist` and draws that history with Errorline() when it finishes.
 *
 * @param {math.Matrix} Ymatrix  one-hot labels, one row per sample (n x k)
 * @param {math.Matrix} Xmatrix  features, one row per sample (n x d)
 * @param {math.Matrix} Weights  starting weights (d x k)
 * @param {number} Iterations    number of descent steps (rounded up; at least one)
 * @returns {math.Matrix} the fitted weights (d x k)
 */
function GradientDescent(Ymatrix, Xmatrix, Weights, Iterations) {
  const sampleCount = Xmatrix.size()[0];
  const learningRate = Number(eta);
  const penalty = Number(lambda);
  const steps = Math.max(1, Math.ceil(Number(Iterations)) || 1);

  let weights = Weights;
  for (let step = 0; step < steps; step++) {
    // Class probabilities: softmax of each row of X·W.
    const scores = math.multiply(Xmatrix, weights).toArray();
    const probabilities = math.matrix(scores.map((row) => Softmax(row)));

    // Record the error of the current weights before updating them.
    Errorlist.push(Error(probabilities, Ymatrix));

    // Gradient of the mean cross-entropy: Xᵀ(P − Y) / n, plus the L2 term.
    const residual = math.subtract(probabilities, Ymatrix);
    const gradient = math.divide(
      math.multiply(math.transpose(Xmatrix), residual),
      sampleCount
    );
    const regularized = math.add(gradient, math.multiply(weights, penalty));
    weights = math.subtract(weights, math.multiply(regularized, learningRate));
  }

  Errorline(Errorlist);
  return weights;
}
/**
 * Total cross-entropy of predicted probabilities against labels: -Σ y·log(p),
 * summed over every sample and class (not averaged).
 *
 * NOTE: this declaration shadows the built-in `Error` constructor within the
 * enclosing function scope, so `new Error(...)` does not create an exception
 * object there.
 *
 * @param {math.Matrix} astart   predicted probabilities, one row per sample
 * @param {math.Matrix} Ymatrix  labels (one-hot rows) with the same shape
 * @returns {number} the summed cross-entropy
 */
function Error(astart, Ymatrix) {
  const probabilities = astart.toArray();
  const labels = Ymatrix.toArray();
  let total = 0;
  labels.forEach((row, i) => {
    row.forEach((label, j) => {
      // Skip zero labels so a zero probability there cannot turn 0·log(0) into NaN.
      if (label !== 0) total -= label * Math.log(probabilities[i][j]);
    });
  });
  return total;
}
/**
 * One-hot encode a fruit code into a length-3 vector.
 * Code 1 sets slot 0 and code 2 sets slot 1; any other value
 * (compared strictly) sets slot 2.
 *
 * @param {*} d  fruit code
 * @returns {number[]} a new [a, b, c] array containing a single 1
 */
function OneHotEncoder(d) {
  const slot = d === 1 ? 0 : d === 2 ? 1 : 2;
  const encoded = [0, 0, 0];
  encoded[slot] = 1;
  return encoded;
}
/**
 * Map a numeric fruit code to its display name.
 * Code 1 maps to "Apple" and code 2 to "Orange"; every other value
 * (compared strictly) maps to "Lemon".
 *
 * @param {*} fruit  fruit code
 * @returns {string} display name
 */
function fruitmap(fruit) {
  switch (fruit) {
    case 1:
      return "Apple";
    case 2:
      return "Orange";
    default:
      return "Lemon";
  }
}
/**
 * Numerically stable softmax.
 *
 * The largest score is subtracted before exponentiating. This does not change
 * the result mathematically, but it keeps large scores from overflowing to
 * Infinity, which would produce NaN probabilities.
 *
 * @param {number[]} Xdata  raw class scores
 * @returns {number[]} probabilities in the same order, summing to 1; [] for []
 */
function Softmax(Xdata) {
  if (Xdata.length === 0) return [];
  const largest = Math.max(...Xdata);
  const exps = Xdata.map((score) => Math.exp(score - largest));
  const total = exps.reduce((sum, value) => sum + value, 0);
  return exps.map((value) => value / total);
}
/**
 * Draw the scatter plot of the training data. Fits the x/y domains to the
 * data, adds both axes, draws one dot per sample coloured by fruit, and adds
 * a colour legend.
 *
 * @param {Array<{fruit: string, width: number, height: number}>} dataInput
 *   rows from fruit.csv after numeric conversion and fruitmap()
 */
function create(dataInput) {
  const xlimits = d3.extent(dataInput, (d) => d.width);
  const ylimits = d3.extent(dataInput, (d) => d.height);
  // Fixed lower bound for the height axis, presumably to leave room below the
  // smallest fruit. TODO confirm the intent of 3.9.
  ylimits[0] = 3.9;
  x.domain(xlimits).nice();
  y.domain(ylimits).nice();

  // x axis along the bottom edge, labelled "Width".
  svg.append("g")
    .attr("class", "x-axis")
    .attr("transform", `translate(0,${height})`)
    .call(xAxis)
    .append("text")
    .attr("class", "label")
    .attr("x", width)
    .attr("y", -6)
    .style("text-anchor", "end")
    .text("Width");

  // y axis along the left edge, labelled "Height".
  svg.append("g")
    .attr("class", "y-axis")
    .call(yAxis)
    .append("text")
    .attr("class", "label")
    .attr("transform", "rotate(-90)")
    .attr("y", 6)
    .attr("dy", ".71em")
    .style("text-anchor", "end")
    .text("Height");

  // One dot per sample. Calling color() here also registers each fruit name
  // in the ordinal scale's domain, which the legend below reads.
  svg.selectAll(".dot")
    .data(dataInput)
    .enter().append("circle")
    .attr("class", "dot")
    .attr("r", 3.5)
    .attr("cx", (d) => x(d.width))
    .attr("cy", (d) => y(d.height))
    .style("fill", (d) => color(d.fruit));

  // Legend: one swatch and label per fruit, stacked from the top-right corner.
  const legend = svg.selectAll(".legend")
    .data(color.domain())
    .enter().append("g")
    .attr("class", "legend")
    .attr("transform", (d, i) => `translate(0,${i * 20})`);

  legend.append("rect")
    .attr("x", width - 18)
    .attr("width", 18)
    .attr("height", 18)
    .style("fill", color);

  legend.append("text")
    .attr("x", width - 24)
    .attr("y", 9)
    .attr("dy", ".35em")
    .style("text-anchor", "end")
    .text((d) => d);
}
/**
 * Draw (or redraw) the training-error curve in the #graph1 panel.
 *
 * Removes any previous curve and label, animates the new curve in over two
 * seconds, and places an "Error" label at the height of its last value.
 *
 * @param {number[]} Errorlist  error recorded at each gradient-descent step
 */
function Errorline(Errorlist) {
  graph.select("path").remove();
  graph.select("text").remove();
  // Nothing to draw; Math.max() of an empty list would be -Infinity.
  if (Errorlist.length === 0) return;

  const gx = d3.scale.linear().domain([0, Errorlist.length]).range([0, 900]);
  const gy = d3.scale.linear().domain([Math.max(...Errorlist), 0]).range([0, 100]);

  const line = d3.svg.line()
    .x((d, i) => gx(i) - 20)
    .y((d) => gy(d));

  // Colour and width are set as inline styles: the page-wide
  // `path { stroke: #666 }` CSS rule overrides SVG presentation attributes.
  const path = graph.append("path")
    .attr("d", line(Errorlist))
    .style("stroke", "steelblue")
    .style("stroke-width", "2px")
    .style("fill", "none");

  // Reveal the curve left to right by animating its dash offset.
  const totalLength = path.node().getTotalLength();
  path
    .attr("stroke-dasharray", `${totalLength} ${totalLength}`)
    .attr("stroke-dashoffset", totalLength)
    .transition()
    .duration(2000)
    .ease("linear")
    .attr("stroke-dashoffset", 0);

  graph.append("text")
    .attr("transform", `translate(${width - 90 + 3},${gy(Errorlist[Errorlist.length - 1])})`)
    .attr("font-size", "10px")
    .text("Error");
}
/**
 * Class probabilities for one point under the fitted weights.
 *
 * Computes one linear score per class (column of Weights) and passes the
 * scores through Softmax.
 *
 * @param {number[]} Datum       feature vector, e.g. [width, height]; its length
 *                               must equal the number of rows of Weights
 * @param {math.Matrix} Weights  features x classes weight matrix
 * @returns {number[]} one probability per class
 */
function Predict(Datum, Weights) {
  const w = Weights.toArray();
  const classCount = w[0].length;
  const scores = [];
  for (let k = 0; k < classCount; k++) {
    scores.push(Datum.reduce((sum, value, f) => sum + value * w[f][k], 0));
  }
  return Softmax(scores);
}
/**
 * Paint the model's decision regions behind the scatter plot.
 *
 * Covers width 5.5–10.5 and height 3–11 with 0.25 x 0.25 tiles. Each tile is
 * coloured by colorinterpolation() of its predicted class probabilities and is
 * more opaque the more confident the prediction is.
 *
 * @param {math.Matrix} Weights  fitted weights from GradientDescent
 */
function Buckets(Weights) {
  const xStep = 0.25;
  const yStep = 0.25;

  // Lower-left corner (in data units) of every tile: 20 columns x 32 rows.
  const corners = [];
  for (let i = 0; i < 20; i++) {
    for (let j = 0; j < 32; j++) {
      corners.push([(i + 22) / 4, (j + 12) / 4]);
    }
  }

  // Predict once per tile; opacity and fill both use the same probabilities.
  const predictions = corners.map((corner) => Predict(corner, Weights));

  svg.selectAll(".tile").remove();
  svg.selectAll(".tile")
    .data(corners)
    .enter()
    // Insert before the first dot so the training points stay drawn on top.
    .insert("rect", ".dot")
    .attr("class", "tile")
    .attr("x", (d) => x(d[0]))
    .attr("y", (d) => y(d[1] + yStep))
    .attr("width", x(xStep) - x(0))
    .attr("height", y(0) - y(yStep))
    .style("opacity", (d, i) => Math.max(...predictions[i]))
    .style("fill", (d, i) => colorinterpolation(predictions[i]));
}
/**
 * Index of the largest element; the first one wins on ties.
 *
 * The original used `!greatest` to detect "no value yet", which misfired
 * whenever the running maximum was 0 (e.g. [0, -1] returned 1). The first
 * index is now tracked explicitly. NaN never compares greater, so NaN entries
 * after the first element are never chosen.
 *
 * @param {number[]} array
 * @returns {number|undefined} index of the maximum, or undefined for []
 */
function findIndexOfGreatest(array) {
  let indexOfGreatest;
  for (let i = 0; i < array.length; i++) {
    if (indexOfGreatest === undefined || array[i] > array[indexOfGreatest]) {
      indexOfGreatest = i;
    }
  }
  return indexOfGreatest;
}
/**
 * Blend a three-class probability vector into a single colour.
 *
 * The first two probabilities choose a point between red and green by their
 * relative share. That colour is then blended toward blue by the third
 * probability. Both blends use d3.interpolateLab.
 *
 * @param {number[]} ColorVector  class probabilities [p0, p1, p2]
 * @returns {string} the blended colour produced by d3.interpolateLab
 */
function colorinterpolation(ColorVector) {
  const [p0, p1, p2] = ColorVector;
  const redGreenTotal = p0 + p1;
  // With no red/green mass the ratio is 0/0. For softmax outputs p2 is then
  // about 1, so the final blend is blue whatever ratio is used here.
  const firstRatio = redGreenTotal > 0 ? p0 / redGreenTotal : 0.5;
  const firstcolor = d3.interpolateLab("red", "green")(firstRatio);
  return d3.interpolateLab(firstcolor, "blue")(p2);
}
| })(); | |
| </script> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment