first impl of linear regression synaptic ca gives a 4x speedup and learning performance is consistently _improved_ relative to prior, across ra25, deep_fsa, and objrec, which usually means that it is genuinely better. Lots more work to be done to explore the space but this is an encouraging start!
rcoreilly committed Jun 12, 2024
1 parent b8cac23 commit 413b3fe
Showing 21 changed files with 125 additions and 65 deletions.
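In brief, the new LinearSynCa approach replaces per-cycle synaptic calcium integration with a per-trial computation: send and recv spikes are counted in four 50-cycle bins over the 200-cycle theta trial, the matched bin products (scaled by 0.1) are fed through a linear-regression fit (SynCaParams.FinalCa in the kinase package), and the resulting CaM / CaP / CaD values drive the trace DWt. Because DoSynCa now opts these pathways out of the per-cycle synaptic Ca pass, that expensive kernel is skipped entirely, which is where the speedup comes from. The following is only a minimal sketch of that data flow, with placeholder coefficients standing in for the fitted regression values:

package main

import "fmt"

// finalCaSketch stands in for kinase SynCaParams.FinalCa: a linear map from the four
// binned send*recv spike products to the CaM / CaP / CaD values. The coefficients
// below are illustrative placeholders, NOT the fitted values in kinase/params.go.
func finalCaSketch(b0, b1, b2, b3 float32) (caM, caP, caD float32) {
	if b0+b1+b2+b3 < 0.01 { // negligible coincident spiking: no Ca (same guard as FinalCa)
		return 0, 0, 0
	}
	caM = 0.4*b0 + 0.3*b1 + 0.2*b2 + 0.1*b3 // hypothetical coefficients
	caP = 0.1*b0 + 0.2*b1 + 0.3*b2 + 0.4*b3 // hypothetical coefficients
	caD = 0.25 * (b0 + b1 + b2 + b3)        // hypothetical coefficients
	return
}

func main() {
	// spike counts per 50-cycle bin for receiver and sender over one 200-cycle trial
	rb := [4]float32{2, 4, 7, 8}
	sb := [4]float32{3, 5, 6, 6}
	var b [4]float32
	for i := range b {
		b[i] = 0.1 * (rb[i] * sb[i]) // same product and 0.1 scaling as in DWtSynCortex
	}
	caM, caP, caD := finalCaSketch(b[0], b[1], b[2], b[3])
	fmt.Println("CaM:", caM, "CaP:", caP, "CaD:", caD, "DWt:", caP-caD)
}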
43 changes: 43 additions & 0 deletions axon/enumgen.go

Some generated files are not rendered by default.

20 changes: 20 additions & 0 deletions axon/learn.go
@@ -682,9 +682,26 @@ func (ls *LRateParams) Init() {
ls.UpdateEff()
}

// SynCaFuns are different ways of computing synaptic calcium (experimental)
type SynCaFuns int32 //enums:enum

const (
// StdSynCa uses standard synaptic calcium integration method
StdSynCa SynCaFuns = iota

// LinearSynCa uses linear regression generated calcium integration (much faster)
LinearSynCa

// NeurSynCa uses simple product of separately-integrated neuron values (much faster)
NeurSynCa
)

// TraceParams manages parameters associated with temporal trace learning
type TraceParams struct {

// how to compute the synaptic calcium (experimental)
SynCa SynCaFuns

// time constant for integrating trace over theta cycle timescales -- governs the decay rate of synaptic trace
Tau float32 `default:"1,2,4"`

@@ -696,9 +713,12 @@ type TraceParams struct {

// rate = 1 / tau
Dt float32 `view:"-" json:"-" xml:"-" edit:"-"`

pad, pad1, pad2 float32
}

func (tp *TraceParams) Defaults() {
tp.SynCa = LinearSynCa
tp.Tau = 1
tp.SubMean = 0
tp.LearnThr = 0
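To select the method per pathway, the new TraceParams.SynCa field can be set directly on the path parameters or via a params entry, as the example projects later in this commit do; a minimal sketch:

// assuming pj is a *PathParams (as in pathparams.go below); LinearSynCa is now the TraceParams default:
pj.Learn.Trace.SynCa = LinearSynCa

// or equivalently via a params.Params selector, as in the ra25 / deep_fsa / objrec examples:
// "Path.Learn.Trace.SynCa": "LinearSynCa",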
55 changes: 31 additions & 24 deletions axon/pathparams.go
@@ -277,8 +277,7 @@ func (pj *PathParams) GatherSpikes(ctx *Context, ly *LayerParams, ni, di uint32,
// DoSynCa returns false if should not do synaptic-level calcium updating.
// Done by default in Cortex, not for some other special pathway types.
func (pj *PathParams) DoSynCa() bool {
if pj.PathType == RWPath || pj.PathType == TDPredPath || pj.PathType == VSMatrixPath ||
pj.PathType == DSMatrixPath || pj.PathType == VSPatchPath || pj.PathType == BLAPath {
if pj.Learn.Trace.SynCa != StdSynCa || pj.PathType == RWPath || pj.PathType == TDPredPath || pj.PathType == VSMatrixPath || pj.PathType == DSMatrixPath || pj.PathType == VSPatchPath || pj.PathType == BLAPath || pj.Learn.Hebb.On.IsTrue() {
return false
}
return true
@@ -338,28 +337,36 @@ func (pj *PathParams) DWtSyn(ctx *Context, syni, si, ri, di uint32, layPool, sub
// Uses synaptically integrated spiking, computed at the Theta cycle interval.
// This is the trace version for hidden units, and uses syn CaP - CaD for targets.
func (pj *PathParams) DWtSynCortex(ctx *Context, syni, si, ri, di uint32, layPool, subPool *Pool, isTarget bool) {
// credit assignment part
caUpT := SynCaV(ctx, syni, di, CaUpT) // time of last update
syCaM := SynCaV(ctx, syni, di, CaM) // fast time scale
syCaP := SynCaV(ctx, syni, di, CaP) // slower but still fast time scale, drives Potentiation
syCaD := SynCaV(ctx, syni, di, CaD) // slow time scale, drives Depression (one trial = 200 cycles)
pj.Learn.KinaseCa.CurCa(ctx.SynCaCtr, caUpT, &syCaM, &syCaP, &syCaD) // always update, getting current Ca (just optimization)

rb0 := NrnV(ctx, ri, di, SpkBin0)
sb0 := NrnV(ctx, si, di, SpkBin0)
rb1 := NrnV(ctx, ri, di, SpkBin1)
sb1 := NrnV(ctx, si, di, SpkBin1)
rb2 := NrnV(ctx, ri, di, SpkBin2)
sb2 := NrnV(ctx, si, di, SpkBin2)
rb3 := NrnV(ctx, ri, di, SpkBin3)
sb3 := NrnV(ctx, si, di, SpkBin3)

b0 := 0.1 * (rb0 * sb0)
b1 := 0.1 * (rb1 * sb1)
b2 := 0.1 * (rb2 * sb2)
b3 := 0.1 * (rb3 * sb3)

pj.Learn.KinaseCa.FinalCa(b0, b1, b2, b3, &syCaM, &syCaP, &syCaD)
var syCaM, syCaP, syCaD, caUpT float32
switch pj.Learn.Trace.SynCa {
case StdSynCa:
caUpT = SynCaV(ctx, syni, di, CaUpT) // time of last update
syCaM = SynCaV(ctx, syni, di, CaM) // fast time scale
syCaP = SynCaV(ctx, syni, di, CaP) // slower but still fast time scale, drives Potentiation
syCaD = SynCaV(ctx, syni, di, CaD) // slow time scale, drives Depression (one trial = 200 cycles)
pj.Learn.KinaseCa.CurCa(ctx.SynCaCtr, caUpT, &syCaM, &syCaP, &syCaD) // always update, getting current Ca (just optimization)
case LinearSynCa:
rb0 := NrnV(ctx, ri, di, SpkBin0)
sb0 := NrnV(ctx, si, di, SpkBin0)
rb1 := NrnV(ctx, ri, di, SpkBin1)
sb1 := NrnV(ctx, si, di, SpkBin1)
rb2 := NrnV(ctx, ri, di, SpkBin2)
sb2 := NrnV(ctx, si, di, SpkBin2)
rb3 := NrnV(ctx, ri, di, SpkBin3)
sb3 := NrnV(ctx, si, di, SpkBin3)

b0 := 0.1 * (rb0 * sb0)
b1 := 0.1 * (rb1 * sb1)
b2 := 0.1 * (rb2 * sb2)
b3 := 0.1 * (rb3 * sb3)

pj.Learn.KinaseCa.FinalCa(b0, b1, b2, b3, &syCaM, &syCaP, &syCaD)
case NeurSynCa:
gain := float32(1.0)
syCaM = gain * NrnV(ctx, si, di, CaSpkM) * NrnV(ctx, ri, di, CaSpkM)
syCaP = gain * NrnV(ctx, si, di, CaSpkP) * NrnV(ctx, ri, di, CaSpkP)
syCaD = gain * NrnV(ctx, si, di, CaSpkD) * NrnV(ctx, ri, di, CaSpkD)
}

SetSynCaV(ctx, syni, di, CaM, syCaM)
SetSynCaV(ctx, syni, di, CaP, syCaP)
4 changes: 1 addition & 3 deletions axon/rand.go
@@ -2,7 +2,6 @@ package axon

import (
"cogentcore.org/core/vgpu/gosl/slrand"
"cogentcore.org/core/vgpu/gosl/sltype"
)

//gosl:hlsl axonrand
@@ -32,8 +31,7 @@ func GetRandomNumber(index uint32, counter slrand.Counter, funIndex RandFunIndex
var randCtr slrand.Counter
randCtr = counter
randCtr.Add(uint32(funIndex))
var ctr sltype.Uint2
ctr = randCtr.Uint2()
ctr := randCtr.Uint2()
return slrand.Float(&ctr, index)
}

2 changes: 1 addition & 1 deletion axon/shaders/Makefile
@@ -2,7 +2,7 @@
# The go generate command does this automatically.

all:
cd ../; gosl -exclude=Update,UpdateParams,Defaults,AllParams,ShouldShow cogentcore.org/core/math32/v2/fastexp.go cogentcore.org/core/etable/v2/minmax ../chans/chans.go ../chans ../kinase ../fsfffb/inhib.go ../fsfffb github.com/emer/emergent/v2/etime github.com/emer/emergent/v2/ringidx rand.go avgmax.go neuromod.go globals.go context.go neuron.go synapse.go pool.go layervals.go act.go act_prjn.go inhib.go learn.go layertypes.go layerparams.go deep_layers.go rl_layers.go pvlv_layers.go pcore_layers.go prjntypes.go prjnparams.go deep_prjns.go rl_prjns.go pvlv_prjns.go pcore_prjns.go hip_prjns.go gpu_hlsl
cd ../; gosl -exclude=Update,UpdateParams,Defaults,AllParams,ShouldShow cogentcore.org/core/math32/fastexp.go cogentcore.org/core/math32/minmax ../chans/chans.go ../chans ../kinase ../fsfffb/inhib.go ../fsfffb github.com/emer/emergent/v2/etime github.com/emer/emergent/v2/ringidx rand.go avgmax.go neuromod.go globals.go context.go neuron.go synapse.go pool.go layervals.go act.go act_prjn.go inhib.go learn.go layertypes.go layerparams.go deep_layers.go rl_layers.go pvlv_layers.go pcore_layers.go prjntypes.go prjnparams.go deep_prjns.go rl_prjns.go pvlv_prjns.go pcore_prjns.go hip_prjns.go gpu_hlsl

# note: gosl automatically compiles the hlsl files using this command:
%.spv : %.hlsl
Binary file modified axon/shaders/gpu_dwt.spv
Binary file modified axon/shaders/gpu_dwtfmdi.spv
Binary file modified axon/shaders/gpu_dwtsubmean.spv
Binary file modified axon/shaders/gpu_gather.spv
Binary file modified axon/shaders/gpu_newstate_pool.spv
Binary file modified axon/shaders/gpu_sendspike.spv
Binary file modified axon/shaders/gpu_synca.spv
Binary file modified axon/shaders/gpu_wtfmdwt.spv
4 changes: 3 additions & 1 deletion axon/typegen.go

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion examples/deep_fsa/params.go
@@ -83,8 +83,9 @@ var ParamSets = netparams.Sets{
}},
{Sel: "Path", Desc: "std",
Params: params.Params{
"Path.Learn.Trace.SynCa": "LinearSynCa",
"Path.Learn.Trace.SubMean": "0", // 0 > 1 -- even with CTCtxt = 0
"Path.Learn.LRate.Base": "0.03", // .03 > others -- same as CtCtxt
"Path.Learn.LRate.Base": "0.02", // .03 > others -- same as CtCtxt
"Path.SWts.Adapt.LRate": "0.01", // 0.01 or 0.0001 music
"Path.SWts.Init.SPct": "1.0", // 1 works fine here -- .5 also ok
"Path.Com.PFail": "0.0",
41 changes: 14 additions & 27 deletions examples/kinaseq/kinaseq.go
@@ -8,23 +8,15 @@ import (
"fmt"
"math/rand"
"reflect"
"strings"

"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/tensor"
"cogentcore.org/core/tensor/stats/stats"
"github.com/emer/emergent/v2/decoder"
"github.com/emer/emergent/v2/elog"
"github.com/emer/emergent/v2/etime"
)

const (
NBins = 20
CyclesPerBin = 10
NOutputs = 3
NInputs = NBins + 2 // per neuron
)

// KinaseNeuron has Neuron state
type KinaseNeuron struct {
// Neuron spiking (0,1)
@@ -42,7 +34,7 @@ type KinaseNeuron struct {
TotalSpikes float32

// binned count of spikes, for regression learning
BinnedSpikes [NBins]float32
BinnedSpikes [4]float32
}

func (kn *KinaseNeuron) Init() {
@@ -58,13 +50,14 @@ func (kn *KinaseNeuron) StartTrial() {
for i := range kn.BinnedSpikes {
kn.BinnedSpikes[i] = 0
}
// kn.CaSyn = 0 // note: better fits with carryover
}

// Cycle does one cycle of neuron updating, with given exponential spike interval
// based on target spiking firing rate.
func (kn *KinaseNeuron) Cycle(expInt float32, params *ParamConfig, cyc int) {
kn.Spike = 0
bin := cyc / CyclesPerBin
bin := cyc / 50
if expInt > 0 {
kn.SpikeP *= rand.Float32()
if kn.SpikeP <= expInt {
@@ -143,8 +136,11 @@ type KinaseState struct {
// Standard synapse values
StdSyn KinaseSynapse

// Linearion synapse values
// Linear synapse values
LinearSyn KinaseSynapse

// binned integration of send, recv spikes
BinnedSums [4]float32
}

func (ks *KinaseState) Init() {
@@ -160,8 +156,6 @@ func (kn *KinaseState) StartTrial() {
}

func (ss *Sim) ConfigKinase() {
ss.Linear.Init(NOutputs, NInputs*2, 0, decoder.IdentityFunc)
ss.Linear.LRate = ss.Config.Params.LRate
}

// Sweep runs a sweep through minus-plus ranges
@@ -268,20 +262,14 @@ func (ss *Sim) TrialImpl(minusHz, plusHz float32) {
}
ks.StdSyn.DWt = ks.StdSyn.CaP - ks.StdSyn.CaD

ks.Send.SetInput(ss.Linear.Inputs, 0)
ks.Recv.SetInput(ss.Linear.Inputs, NInputs)
ss.Linear.Forward()
out := make([]float32, NOutputs)
ss.Linear.Output(&out)
ks.LinearSyn.CaM = out[0]
ks.LinearSyn.CaP = out[1]
ks.LinearSyn.CaD = out[2]
for i := range ks.BinnedSums {
ks.BinnedSums[i] = 0.1 * (ks.Recv.BinnedSpikes[i] * ks.Send.BinnedSpikes[i])
}

ss.CaParams.FinalCa(ks.BinnedSums[0], ks.BinnedSums[1], ks.BinnedSums[2], ks.BinnedSums[3], &ks.LinearSyn.CaM, &ks.LinearSyn.CaP, &ks.LinearSyn.CaD)
ks.LinearSyn.DWt = ks.LinearSyn.CaP - ks.LinearSyn.CaD

if ks.Train {
targ := [NOutputs]float32{ks.StdSyn.CaM, ks.StdSyn.CaP, ks.StdSyn.CaD}
sse, _ := ss.Linear.Train(targ[:])
ks.SSE = sse
ss.Logs.LogRow(etime.Train, etime.Cycle, 0)
ss.GUI.UpdatePlot(etime.Train, etime.Cycle)
ss.Logs.LogRow(etime.Train, etime.Trial, ks.Trial)
@@ -308,7 +296,6 @@ func (ss *Sim) Train() {
ss.Logs.LogRow(etime.Train, etime.Condition, ss.Kinase.Condition)
ss.GUI.UpdatePlot(etime.Train, etime.Condition)
}
tensor.SaveCSV(&ss.Linear.Weights, "trained.wts", '\t')
}

func (ss *Sim) ConfigKinaseLogItems() {
@@ -320,7 +307,7 @@ func (ss *Sim) ConfigKinaseLogItems() {
tn := len(times)
WalkFields(val,
func(parent reflect.Value, field reflect.StructField, value reflect.Value) bool {
if field.Name == "BinnedSpikes" {
if strings.HasPrefix(field.Name, "Binned") {
return false
}
return true
4 changes: 2 additions & 2 deletions examples/kinaseq/neuron.go
@@ -189,10 +189,10 @@ func (ss *Sim) NeuronUpdate(nt *axon.Network, inputOn bool) {
syni := uint32(0)
pj := ly.RcvPaths[0]

snCaSyn := pj.Params.Learn.KinaseCa.SpikeG * axon.NrnV(ctx, ni, di, axon.CaSyn)
snCaSyn := pj.Params.Learn.KinaseCa.CaScale * axon.NrnV(ctx, ni, di, axon.CaSyn)
pj.Params.SynCaSyn(ctx, syni, ri, di, snCaSyn, updtThr)

rnCaSyn := pj.Params.Learn.KinaseCa.SpikeG * axon.NrnV(ctx, ri, di, axon.CaSyn)
rnCaSyn := pj.Params.Learn.KinaseCa.CaScale * axon.NrnV(ctx, ri, di, axon.CaSyn)
if axon.NrnV(ctx, si, di, axon.Spike) <= 0 { // NOT already handled in send version
pj.Params.SynCaSyn(ctx, syni, si, di, rnCaSyn, updtThr)
}
2 changes: 1 addition & 1 deletion examples/kinaseq/sim.go
@@ -56,7 +56,7 @@ type Sim struct {
Config Config

// Kinase SynCa params
CaParams kinase.CaParams
CaParams kinase.SynCaParams

// Kinase state
Kinase KinaseState
1 change: 1 addition & 0 deletions examples/objrec/params.go
@@ -69,6 +69,7 @@ var ParamSets = netparams.Sets{
}},
{Sel: "Path", Desc: "yes extra learning factors",
Params: params.Params{
"Path.Learn.Trace.SynCa": "LinearSynCa",
"Path.Learn.LRate.Base": "0.2", // 0.4 for NeuronCa; 0.2 best, 0.1 nominal
"Path.Learn.Trace.SubMean": "1", // 1 -- faster if 0 until 20 epc -- prevents sig amount of late deterioration
"Path.SWts.Adapt.LRate": "0.0001", // 0.005 == .1 == .01
9 changes: 5 additions & 4 deletions examples/ra25/params.go
@@ -41,10 +41,11 @@ var ParamSets = netparams.Sets{
}},
{Sel: "Path", Desc: "basic path params",
Params: params.Params{
"Path.Learn.LRate.Base": "0.1", // 0.1 learns fast but dies early, .02 is stable long term
"Path.SWts.Adapt.LRate": "0.1", // .1 >= .2,
"Path.SWts.Init.SPct": "0.5", // .5 >= 1 here -- 0.5 more reliable, 1.0 faster..
"Path.Learn.Trace.SubMean": "0", // 1 > 0 for long run stability
"Path.Learn.Trace.SynCa": "LinearSynCa",
"Path.Learn.LRate.Base": "0.05", // 0.1 learns fast but dies early, .02 is stable long term
"Path.SWts.Adapt.LRate": "0.1", // .1 >= .2,
"Path.SWts.Init.SPct": "0.5", // .5 >= 1 here -- 0.5 more reliable, 1.0 faster..
"Path.Learn.Trace.SubMean": "0", // 1 > 0 for long run stability
}},
{Sel: ".BackPath", Desc: "top-down back-pathways MUST have lower relative weight scale, otherwise network hallucinates",
Params: params.Params{
2 changes: 1 addition & 1 deletion kinase/params.go
@@ -240,7 +240,7 @@ func (kp *SynCaParams) CurCa(ctime, utime float32, caM, caP, caD *float32) {

// FinalCa uses a linear regression to compute the final Ca values
func (kp *SynCaParams) FinalCa(bin0, bin1, bin2, bin3 float32, caM, caP, caD *float32) {
if bin0+bin1+bin2+bin3 < 0.1 {
if bin0+bin1+bin2+bin3 < 0.01 {
*caM = 0
*caP = 0
*caD = 0
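For reference, a minimal usage sketch of the function touched in this hunk, assuming the axon v2 module path and a standard Defaults() initializer on SynCaParams (neither is shown in this diff): the four binned products go in, CaM / CaP / CaD come out, and the trace DWt is CaP - CaD, as in DWtSynCortex and the kinaseq example above.

package main

import (
	"fmt"

	"github.com/emer/axon/v2/kinase" // assumed module path
)

func main() {
	kp := kinase.SynCaParams{}
	kp.Defaults() // assumed standard Defaults() initializer

	// binned send*recv spike products, scaled by 0.1 as in DWtSynCortex
	b := [4]float32{0.6, 2.0, 4.2, 4.8}

	var caM, caP, caD float32
	kp.FinalCa(b[0], b[1], b[2], b[3], &caM, &caP, &caD)
	fmt.Println("CaM:", caM, "CaP:", caP, "CaD:", caD, "DWt:", caP-caD)
}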
