diff --git a/dann.js b/dann.js index 7da97add..f602f621 100644 --- a/dann.js +++ b/dann.js @@ -352,7 +352,89 @@ class Matrix { } } } +//Plot any Dann neural Network: +class NetPlot { + constructor(x,y,w,h,nn) { + this.pos = createVector(x,y); + this.w = w; + this.h = h; + this.nn = nn; + this.spacingY = h/(this.nn.i-1); + this.layerSpacing = w/(this.nn.Layers.length-1); + console.log(this.layerSpacing) + this.bufferY = this.spacingY/2; + } + renderWeights() { + stroke(100); + for (let i = 0; i < this.nn.Layers.length; i++) { + + let layer = Matrix.toArray(this.nn.Layers[i]); + + this.spacingY = (this.h/(layer.length)); + this.bufferY = this.spacingY/2 + + if (i !== this.nn.Layers.length-1) { + + let nextLayer = Matrix.toArray(this.nn.Layers[i+1]); + let sY = (this.h/(nextLayer.length)); + let bY = sY/2; + + for (let j = 0; j < nextLayer.length; j++) { + + let x = this.pos.x+((i+1)*this.layerSpacing); + let y = this.pos.y+bY+((j)*sY); + let x2 = 0; + let y2 = 0 + + for (let k = 0; k < layer.length; k++) { + + let weights = (this.nn.weights[i]).matrix; + x2 = this.pos.x+((i)*this.layerSpacing); + y2 = this.pos.y+this.bufferY+((k)*this.spacingY); + stroke(weightToColor(weights[j][k])); + strokeWeight(map(int(weights[j][k]*1000)/1000,0,1,1,2)); + line(x,y,x2,y2); + + } + } + } + + + } + } + renderLayers() { + fill(255); + stroke(0); + strokeWeight(1) + for (let i = 0; i < this.nn.Layers.length; i++) { + + let layer = Matrix.toArray(this.nn.Layers[i]); + this.spacingY = (this.h/(layer.length)); + this.bufferY = this.spacingY/2; + for (let j = 0; j < layer.length; j++) { + + let x = this.pos.x+((i)*this.layerSpacing); + let y = this.pos.y+this.bufferY+((j)*this.spacingY); + + ellipse(x,y,8,8); + + } + } + } + render() { + noFill(); + stroke(0); + rect(this.pos.x,this.pos.y,this.w,this.h); + if (dragged&&mouseX >= this.pos.x && mouseX<=this.pos.x+this.w&&mouseY >= this.pos.y&&mouseY<=this.pos.y+this.h) { + this.pos.x = mouseX-(this.w/2); + this.pos.y = mouseY-(this.h/2); + } + this.renderWeights(); + this.renderLayers(); + + } +} // Graph (graph any values over time): class Graph { constructor(x,y,w,h) { @@ -373,7 +455,7 @@ class Graph { update() { noFill(); rect(this.pos.x,this.pos.y,this.w,this.h); - if (this.dragged&&mouseX >= this.pos.x && mouseX<=this.pos.x+this.w&&mouseY >= this.pos.y&&mouseY<=this.pos.y+this.h) { + if (dragged&&mouseX >= this.pos.x && mouseX<=this.pos.x+this.w&&mouseY >= this.pos.y&&mouseY<=this.pos.y+this.h) { this.pos.x = mouseX-(this.w/2); this.pos.y = mouseY-(this.h/2); }