-
Notifications
You must be signed in to change notification settings - Fork 36
/
init.lua
45 lines (40 loc) · 1.69 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
require 'nn'
require 'optim'
require 'libunsup'
-- extra modules
torch.include('unsup', 'Diag.lua')
-- classes that implement algorithms
torch.include('unsup', 'UnsupModule.lua')
torch.include('unsup', 'AutoEncoder.lua')
torch.include('unsup', 'SparseAutoEncoder.lua')
torch.include('unsup', 'FistaL1.lua')
torch.include('unsup', 'LinearFistaL1.lua')
torch.include('unsup', 'SpatialConvFistaL1.lua')
torch.include('unsup', 'psd.lua')
torch.include('unsup', 'LinearPsd.lua')
torch.include('unsup', 'ConvPsd.lua')
torch.include('unsup', 'UnsupTrainer.lua')
torch.include('unsup', 'pca.lua')
torch.include('unsup', 'kmeans.lua')
torch.include('unsup', 'whitening.lua')
-- Keep a handle on nn's own enable() so the override below can chain to it.
local oldhessian = nn.hessian.enable

--- Enable diagonal-Hessian support, including for nn.Diag.
-- Replaces nn.hessian.enable. It first calls the original implementation
-- (forwarding any arguments it was given). It then installs the
-- diagonal-Hessian methods on nn.Diag, which delegate to the generic helpers
-- found on nn.hessian.
function nn.hessian.enable(...)
   oldhessian(...) -- enable Hessian usage in nn itself

   ----------------------------------------------------------------------
   -- Diag
   ----------------------------------------------------------------------
   -- These helpers are read from nn.hessian only after the original
   -- enable() has run.
   -- NOTE(review): presumably the original enable() is what populates these
   -- fields; confirm against nn's hessian.lua.
   local accDiagHessianParameters = nn.hessian.accDiagHessianParameters
   local updateDiagHessianInput = nn.hessian.updateDiagHessianInput
   local initDiagHessianParameters = nn.hessian.initDiagHessianParameters

   --- Diagonal-Hessian counterpart of updateGradInput for nn.Diag.
   -- Delegates to nn.hessian.updateDiagHessianInput, pairing the 'weight'
   -- field with its 'weightSq' companion.
   -- @param input the module input
   -- @param diagHessianOutput diagonal Hessian w.r.t. the module output
   -- @return self.diagHessianInput
   function nn.Diag.updateDiagHessianInput(self, input, diagHessianOutput)
      updateDiagHessianInput(self, input, diagHessianOutput, {'weight'}, {'weightSq'})
      return self.diagHessianInput
   end

   --- Diagonal-Hessian counterpart of accGradParameters for nn.Diag.
   -- Delegates to nn.hessian.accDiagHessianParameters, pairing the
   -- 'gradWeight' field with 'diagHessianWeight'.
   -- @param input the module input
   -- @param diagHessianOutput diagonal Hessian w.r.t. the module output
   function nn.Diag.accDiagHessianParameters(self, input, diagHessianOutput)
      accDiagHessianParameters(self, input, diagHessianOutput, {'gradWeight'}, {'diagHessianWeight'})
   end

   --- Set up the diagonal-Hessian parameter fields of nn.Diag.
   -- Delegates to nn.hessian.initDiagHessianParameters, pairing the
   -- 'gradWeight' field with 'diagHessianWeight'.
   function nn.Diag.initDiagHessianParameters(self)
      initDiagHessianParameters(self, {'gradWeight'}, {'diagHessianWeight'})
   end
end