Skip to content

Commit 07cc815

Browse files
committed
add a simple example for using matrix multiplication layer
1 parent cc4f1f1 commit 07cc815

File tree

7 files changed

+165
-0
lines changed

7 files changed

+165
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import caffe
2+
import numpy as np
3+
4+
class LineDataLayer(caffe.Layer):
    """Synthetic data source for a two-input line-fitting toy problem.

    Each forward pass fills the two tops with a freshly sampled batch:
    ``top[0]`` receives random inputs ``x`` (two values per sample) and
    ``top[1]`` receives the targets ``y = W . x + b`` computed from the
    fixed ground-truth ``W`` and ``b`` chosen in :meth:`setup`.
    """

    def setup(self, bottom, top):
        """Fix the batch size, shape both tops and define the ground truth."""
        self.batch_size = 10
        top[0].reshape(self.batch_size, 2, 1)
        top[1].reshape(self.batch_size)

        # Ground-truth line parameters the network is expected to recover.
        self.W = np.array([2, 5], dtype=np.float32)
        self.b = -7.0
        print("W=\n" + str(self.W))
        print("b=" + str(self.b))

    def forward(self, bottom, top):
        """Sample one batch of (x, y) pairs into ``top[0]`` / ``top[1]``."""
        # `range` rather than the Python-2-only `xrange`, so the layer also
        # loads under Python 3. The index is named `i` to avoid confusion
        # with the bias attribute `self.b`.
        for i in range(self.batch_size):
            x = np.random.rand(2).astype(np.float32)
            y = self.W.dot(x) + self.b
            top[0].data[i, ...] = x.reshape(2, 1)
            top[1].data[i, ...] = y

    def reshape(self, bottom, top):
        """Nothing to do: the top shapes are fixed once in setup()."""
        pass

    def backward(self, top, propagate_down, bottom):
        """Nothing to do: this layer has no bottoms to propagate to."""
        pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Solver settings for the InnerProduct variant of the line-fitting example.
train_net: "train_line_fitting_ip.ptt"
solver_mode: GPU

# Learning-rate schedule.
base_lr: 1.0e-1
lr_policy: "poly"
power: 1
max_iter: 5000
# NOTE(review): gamma and stepsize look unused under the "poly" policy
# (they belong to step/exp/inv-style schedules) - confirm against caffe.proto.
gamma: 0.1
stepsize: 2000

# Optimisation.
momentum: 0.9
weight_decay: 0
iter_size: 10

# Logging and snapshots.
display: 1
average_loss: 20
snapshot: 5000
snapshot_prefix: "train_line_fitting_ip"
+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import numpy as np
2+
import sys
3+
import caffe
4+
5+
# Select GPU execution on device 0 before the solver is built.
caffe.set_mode_gpu()
caffe.set_device(0)


def _dump(title, values):
    """Print a header line followed by a NumPy copy of the blob data."""
    print(title)
    print(np.array(values))


solver = caffe.SGDSolver("solve_ip.ptt")

# Weights as initialised by the net's filler, before any training step.
_dump("W_0=\n", solver.net.params['ip_WXpb'][0].data)

solver.solve()

# Learned parameters: blob 0 is printed as W, blob 1 as b.
_dump("W=\n", solver.net.params['ip_WXpb'][0].data)
_dump("b=\n", solver.net.params['ip_WXpb'][1].data)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Solver settings for the MatrixMultiplication variant of the line-fitting example.
train_net: "train_line_fitting_matmul.ptt"
solver_mode: GPU

# Learning-rate schedule.
base_lr: 1.0e-1
lr_policy: "poly"
power: 1
max_iter: 5000
# NOTE(review): gamma and stepsize look unused under the "poly" policy
# (they belong to step/exp/inv-style schedules) - confirm against caffe.proto.
gamma: 0.1
stepsize: 2000

# Optimisation.
momentum: 0.9
weight_decay: 0
iter_size: 10

# Logging and snapshots.
display: 1
average_loss: 20
snapshot: 5000
snapshot_prefix: "train_line_fitting_matmul"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import numpy as np
2+
import sys
3+
import caffe
4+
5+
# Select GPU execution on device 0 before the solver is built.
caffe.set_mode_gpu()
caffe.set_device(0)

solver = caffe.SGDSolver("solve_matmul.ptt")

# Overwrite the 'W' blob with a standard-normal 1x2 draw before training.
initial_w = np.random.randn(1, 2)
solver.net.blobs['W'].data[...] = initial_w

print("W_0=\n")
print(np.array(solver.net.blobs['W'].data))

solver.solve()

# After training: the 'W' blob, then the first parameter blob of the
# "add_WX_b" layer, which is printed as b.
print("W=\n")
print(np.array(solver.net.blobs['W'].data))
print("b=\n")
bias_blob = solver.net.params["add_WX_b"][0]
print(bias_blob.data)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Python data layer: class LineDataLayer from module py_data_layer
# produces the two tops "X" and "Y".
layer {
  type: "Python"
  name: "data"
  python_param {
    layer: "LineDataLayer"
    module: "py_data_layer"
  }
  top: "X"
  top: "Y"
}
11+
# Single-output InnerProduct layer mapping "X" to "WXpb".
layer {
  type: "InnerProduct"
  name: "ip_WXpb"
  bottom: "X"
  top: "WXpb"  # 10x1x1
  # Two learnable blobs, both with weight decay disabled.
  param { decay_mult: 0 }
  param { decay_mult: 0 }
  inner_product_param {
    num_output: 1
    weight_filler { type: "gaussian" std: 0.01 }
    bias_filler { type: "constant" value: 0 }
  }
}
30+
# Euclidean (L2) loss between the prediction "WXpb" and the target "Y".
layer {
  type: "EuclideanLoss"
  name: "loss_L2"
  bottom: "WXpb"
  bottom: "Y"
  top: "loss"
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Python data layer: class LineDataLayer from module py_data_layer
# produces the two tops "X" and "Y".
layer {
  type: "Python"
  name: "data"
  python_param {
    layer: "LineDataLayer"
    module: "py_data_layer"
  }
  top: "X"
  top: "Y"
}
11+
# Parameter layer exposing a 1x2 blob as the top "W".
layer {
  type: "Parameter"
  name: "param_W"
  top: "W"
  parameter_param {
    shape { dim: 1 dim: 2 }
  }
}
17+
# Matrix product of "W" and "X", written to "WX".
layer {
  type: "MatrixMultiplication"
  name: "matmul_W_X"
  bottom: "W"
  bottom: "X"
  top: "WX"  # 10x1x1
}
24+
# Bias layer: adds a learned bias (num_axes: 0, weight decay disabled)
# to "WX", producing "WXpb".
layer {
  type: "Bias"
  name: "add_WX_b"
  bottom: "WX"
  top: "WXpb"
  param { decay_mult: 0 }
  bias_param { num_axes: 0 }
}
34+
# Euclidean (L2) loss between the prediction "WXpb" and the target "Y".
layer {
  type: "EuclideanLoss"
  name: "loss_L2"
  bottom: "WXpb"
  bottom: "Y"
  top: "loss"
}

0 commit comments

Comments
 (0)