-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrbf.jl
108 lines (81 loc) · 1.95 KB
/
rbf.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
using JSON
using LinearAlgebra
"""
    sqnorm(a, b)

Return the squared Euclidean distance `‖a - b‖²` between `a` and `b`.

The difference is formed first and then squared element-wise, rather than
expanding to `a⋅a + b⋅b - 2a⋅b`: the expanded form loses precision through
catastrophic cancellation when `a` and `b` are large and close together, and
can even come out slightly negative.

Throws `DimensionMismatch` if `a` and `b` are arrays of different sizes.
"""
function sqnorm(a, b)
    return sum(abs2, a - b)
end
"""
    rbf_kernel(x, y, γ)

Gaussian radial-basis kernel `exp(-γ * ‖x - y‖²)`, where the squared distance
comes from `sqnorm`. `γ` sets the width: for positive `γ`, a larger value makes
the activation decay faster as `y` moves away from `x`.
"""
rbf_kernel(x, y, γ) = exp(-γ * sqnorm(x, y))
"""
    rbf_compute(x, cs = centres, ws = wts; γ = 1.0)

Evaluate the RBF network on input `x` and return its output clamped to `[0, 1]`.

The hidden layer holds one kernel activation `rbf_kernel(x, c, γ)` per centre
`c` in `cs`. A constant `1.0` is prepended to that layer so that `ws[1]` acts as
a bias weight; the raw output is `dot([1.0; hidden], ws)`.

By default the centres and weights are taken from the script-level globals
`centres` and `wts` (looked up at call time), so existing calls `rbf_compute(x)`
behave as before. Passing them explicitly removes that hidden global dependency.

Throws `DimensionMismatch` unless `length(ws) == length(cs) + 1`.
"""
function rbf_compute(x, cs = centres, ws = wts; γ = 1.0)
    # Fail with a descriptive message instead of an opaque error deep inside `dot`.
    if length(ws) != length(cs) + 1
        throw(DimensionMismatch(
            "expected $(length(cs) + 1) weights (bias + one per centre), got $(length(ws))",
        ))
    end
    # Map the input into the hidden layer: one kernel activation per centre.
    hidden = [rbf_kernel(x, c, γ) for c in cs]
    # Leading 1.0 pairs with the bias weight ws[1]; clamp keeps the output in [0, 1].
    return clamp(dot([1.0; hidden], ws), 0.0, 1.0)
end
"""
    read_dat_vector(filename::AbstractString) -> Vector{Float64}

Read a text file holding one floating-point number per line and return the
values in file order.

Blank or whitespace-only lines are skipped, so an extra empty line at the end
of the file does not cause a parse error. The file is opened with a `do` block,
so it is closed even if a line fails to parse. In that case the `ArgumentError`
from `parse` propagates to the caller.
"""
function read_dat_vector(filename::AbstractString)
    values = Float64[]
    open(filename) do io
        for line in eachline(io)
            s = strip(line)
            isempty(s) && continue  # tolerate blank lines
            push!(values, parse(Float64, s))
        end
    end
    return values
end
# Load the network weights (rbf_compute dots these with [1.0; hidden], so the
# first entry pairs with the constant 1.0) and `testout` (presumably the expected
# outputs for testdata.dat — it is only printed in this script), then echo both.
wts = read_dat_vector("wts.dat")
testout = read_dat_vector("testout.dat")
foreach(println, (wts, testout))
# Parse a whitespace-separated numeric table into one Vector{Float64} per line.
# Blank lines are skipped, because an empty row would otherwise only fail much
# later, far from its cause. The do-block closes the file even if parsing throws.
function read_dat_rows(filename::AbstractString)
    rows = Vector{Float64}[]
    open(filename) do io
        for line in eachline(io)
            fields = split(line)
            isempty(fields) || push!(rows, parse.(Float64, fields))
        end
    end
    return rows
end

# RBF centres, one per line (three Float64 values per line in this data set).
centres = read_dat_rows("centres.dat")

# Test cases, one per line; the loop below reads the first three columns as the
# input and the last column as the target.
testdata = read_dat_rows("testdata.dat")
# Run the network on every test case: the first three columns of a row are the
# input and the last column is the target, which is printed for comparison.
for row in testdata
    local x, target, out
    x, target = row[1:3], row[end]
    println("x = $x, target = $target")
    out = rbf_compute(x)
    println("out: $out")
end
# Write the network parameters out as JSON: weights first, then centres.
for (path, data) in ("wts.json" => wts, "centres.json" => centres)
    open(path, "w") do io
        JSON.print(io, data)
    end
end
# Demo: score a single colour. The 0–255 RGB channels are rescaled to [0, 1]
# before being fed to the network.
rgb = [255.0, 107.0, 0.0]
x = rgb ./ 255.0
println("rgb: $rgb, x: $x")
out = x |> rbf_compute
println("out: $out")