Skip to content

Commit 5273b75

Browse files
committed
4var PID
1 parent 72be96b commit 5273b75

File tree

10 files changed

+689
-225
lines changed

10 files changed

+689
-225
lines changed

benchmarks/entropy_baseline.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,31 @@
1111
import numpy as np
1212
import infotheory
1313
import matplotlib
14-
matplotlib.use('TkAgg')
14+
15+
matplotlib.use("TkAgg")
1516
import matplotlib.pyplot as plt
1617

1718
# range of probabilties for HEADS in coin flip
18-
p = np.arange(0,1.01,0.01)
19+
p = np.arange(0, 1.01, 0.01)
1920
entropies = []
2021
for pi in p:
2122
it = infotheory.InfoTools(1, 0)
2223
it.set_equal_interval_binning([2], [0], [1])
2324

2425
# flipping coin 10000 times
2526
for _ in range(10000):
26-
if np.random.rand()<pi: it.add_data_point([0])
27-
else: it.add_data_point([1])
27+
if np.random.rand() < pi:
28+
it.add_data_point([0])
29+
else:
30+
it.add_data_point([1])
2831

2932
# estimating entropy
3033
entropies.append(it.entropy([0]))
3134

32-
plt.figure(figsize=[3,2])
35+
plt.figure(figsize=[3, 2])
3336
plt.plot(p, entropies)
34-
plt.xlabel('Probability of HEADS')
35-
plt.ylabel('Entropy')
37+
plt.xlabel("Probability of HEADS")
38+
plt.ylabel("Entropy")
3639
plt.tight_layout()
37-
#plt.savefig('./entropy_baseline.png')
40+
# plt.savefig('./entropy_baseline.png')
3841
plt.show()

benchmarks/mi_baseline.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,61 +10,63 @@
1010
import infotheory
1111
import numpy as np
1212
import matplotlib
13-
matplotlib.use('TkAgg')
13+
14+
matplotlib.use("TkAgg")
1415
import matplotlib.pyplot as plt
1516

1617

17-
def sub_plt(ind,data,mi):
18+
def sub_plt(ind, data, mi):
1819
plt.subplot(ind)
19-
plt.plot(data[0],data[1],'.',markersize=2)
20-
plt.xlim([0,1])
21-
plt.ylim([0,1])
22-
plt.xlabel('Random variable, X', fontsize=12)
23-
plt.ylabel('Random variable, Y', fontsize=12)
24-
plt.title('Mutual information\n {}'.format(np.round(mi,4)), fontsize=13)
20+
plt.plot(data[0], data[1], ".", markersize=2)
21+
plt.xlim([0, 1])
22+
plt.ylim([0, 1])
23+
plt.xlabel("Random variable, X", fontsize=12)
24+
plt.ylabel("Random variable, Y", fontsize=12)
25+
plt.title("Mutual information\n {}".format(np.round(mi, 4)), fontsize=13)
26+
2527

26-
plt.figure(figsize=[10,3])
28+
plt.figure(figsize=[10, 3])
2729

2830
# identical variables
29-
x = np.arange(0,1,1/1000)
31+
x = np.arange(0, 1, 1 / 1000)
3032
y = np.flipud(x)
3133
it = infotheory.InfoTools(2, 0)
32-
it.set_equal_interval_binning([50]*2, [0]*2, [1]*2)
33-
it.add_data(np.vstack([x,y]).T)
34-
mi = it.mutual_info([0,1])/np.log2(50)
35-
sub_plt(141,[x,y],mi)
34+
it.set_equal_interval_binning([50] * 2, [0] * 2, [1] * 2)
35+
it.add_data(np.vstack([x, y]).T)
36+
mi = it.mutual_info([0, 1]) / np.log2(50)
37+
sub_plt(141, [x, y], mi)
3638

3739
# shuffled identical variables
38-
inds = np.arange(0.05,1,0.1)
39-
x = [np.random.normal(loc=id,scale=0.015,size=[30]) for id in inds]
40+
inds = np.arange(0.05, 1, 0.1)
41+
x = [np.random.normal(loc=id, scale=0.015, size=[30]) for id in inds]
4042
x = np.asarray(x).flatten()
4143
s_inds = np.random.permutation(inds)
42-
y = [np.random.normal(loc=id,scale=0.015,size=[30]) for id in s_inds]
44+
y = [np.random.normal(loc=id, scale=0.015, size=[30]) for id in s_inds]
4345
y = np.asarray(y).flatten()
4446
it = infotheory.InfoTools(2, 0)
45-
it.set_equal_interval_binning([10]*2, [0]*2, [1]*2)
46-
it.add_data(np.vstack([x,y]).T)
47-
mi = it.mutual_info([0,1])/np.log2(10)
48-
sub_plt(142,[x,y],mi)
47+
it.set_equal_interval_binning([10] * 2, [0] * 2, [1] * 2)
48+
it.add_data(np.vstack([x, y]).T)
49+
mi = it.mutual_info([0, 1]) / np.log2(10)
50+
sub_plt(142, [x, y], mi)
4951

5052
# noisy identical variables
5153
x = np.random.rand(1000)
52-
y = x + ((np.random.rand(1000)-0.5)/10)
54+
y = x + ((np.random.rand(1000) - 0.5) / 10)
5355
it = infotheory.InfoTools(2, 0)
54-
it.set_equal_interval_binning([20]*2, [np.min(x),np.min(y)], [np.max(x),np.max(y)])
55-
it.add_data(np.vstack([x,y]).T)
56-
mi = it.mutual_info([0,1])/np.log2(20)
57-
sub_plt(143,[x,y],mi)
56+
it.set_equal_interval_binning([20] * 2, [np.min(x), np.min(y)], [np.max(x), np.max(y)])
57+
it.add_data(np.vstack([x, y]).T)
58+
mi = it.mutual_info([0, 1]) / np.log2(20)
59+
sub_plt(143, [x, y], mi)
5860

5961
# random variables
6062
x = np.random.rand(1000)
6163
y = np.random.rand(1000)
6264
it = infotheory.InfoTools(2, 0)
63-
it.set_equal_interval_binning([20]*2, [np.min(x),np.min(y)], [np.max(x),np.max(y)])
64-
it.add_data(np.vstack([x,y]).T)
65-
mi = it.mutual_info([0,1])/np.log2(20)
66-
sub_plt(144,[x,y],mi)
65+
it.set_equal_interval_binning([20] * 2, [np.min(x), np.min(y)], [np.max(x), np.max(y)])
66+
it.add_data(np.vstack([x, y]).T)
67+
mi = it.mutual_info([0, 1]) / np.log2(20)
68+
sub_plt(144, [x, y], mi)
6769

6870
plt.tight_layout()
69-
#plt.savefig('mi_baseline.png')
71+
# plt.savefig('mi_baseline.png')
7072
plt.show()

benchmarks/pid_baseline.py

Lines changed: 124 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,45 +8,149 @@
88
###############################################################################
99
import infotheory
1010

11+
# 2 input AND gate
12+
print("2-input logical AND")
13+
data = [[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 1]]
14+
15+
# creating the object and adding data
16+
it_and = infotheory.InfoTools(3, 0)
17+
it_and.set_equal_interval_binning([2] * 3, [0] * 3, [1] * 3)
18+
it_and.add_data(data)
19+
20+
# PID-ing
21+
total_mi = it_and.mutual_info([0, 0, 1])
22+
redundant_info = it_and.redundant_info([1, 2, 0])
23+
unique_1 = it_and.unique_info([1, 2, 0])
24+
unique_2 = it_and.unique_info([2, 1, 0])
25+
synergy = it_and.synergy([1, 2, 0])
26+
27+
print("total_mi = {}".format(total_mi))
28+
print("redundant_info = {}".format(redundant_info))
29+
print("unique_1 = {}".format(unique_1))
30+
print("unique_2 = {}".format(unique_2))
31+
print("synergy = {}\n".format(synergy))
32+
33+
1134
# 2 input xor gate
12-
data = [[0,0,0],[0,1,1],[1,0,1],[1,1,0]]
35+
print("2-input logical XOR")
36+
data = [[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0]]
1337

1438
# creating the object and adding data
1539
it_xor = infotheory.InfoTools(3, 0)
16-
it_xor.set_equal_interval_binning([2]*3, [0]*3, [1]*3)
40+
it_xor.set_equal_interval_binning([2] * 3, [0] * 3, [1] * 3)
1741
it_xor.add_data(data)
1842

1943
# PID-ing
20-
print("2-input XOR")
21-
total_mi = it_xor.mutual_info([0,0,1])
22-
redundant_info = it_xor.redundant_info([1,2,0])
23-
unique_1 = it_xor.unique_info([1,2,0])
24-
unique_2 = it_xor.unique_info([1,2,0])
25-
synergy = it_xor.synergy([1,2,0])
44+
total_mi = it_xor.mutual_info([0, 0, 1])
45+
redundant_info = it_xor.redundant_info([1, 2, 0])
46+
unique_1 = it_xor.unique_info([1, 2, 0])
47+
unique_2 = it_xor.unique_info([1, 2, 0])
48+
synergy = it_xor.synergy([1, 2, 0])
49+
2650
print("total_mi = {}".format(total_mi))
2751
print("redundant_info = {}".format(redundant_info))
2852
print("unique_1 = {}".format(unique_1))
2953
print("unique_2 = {}".format(unique_2))
3054
print("synergy = {}\n".format(synergy))
3155

3256

33-
# 2 input AND gate
34-
data = [[0,0,0],[0,1,0],[1,0,0],[1,1,1]]
57+
# 3 input OR
58+
print("3-input logical AND")
59+
data = [
60+
[0, 0, 0, 0],
61+
[0, 0, 1, 0],
62+
[0, 1, 0, 0],
63+
[0, 1, 1, 0],
64+
[1, 0, 0, 0],
65+
[1, 0, 1, 0],
66+
[1, 1, 0, 0],
67+
[1, 1, 1, 1],
68+
]
3569

3670
# creating the object and adding data
37-
it_and = infotheory.InfoTools(3, 0)
38-
it_and.set_equal_interval_binning([2]*3, [0]*3, [1]*3)
39-
it_and.add_data(data)
71+
it_3and = infotheory.InfoTools(4, 0)
72+
it_3and.set_equal_interval_binning([2] * 4, [0] * 4, [1] * 4)
73+
it_3and.add_data(data)
4074

4175
# PID-ing
42-
print("2-input AND")
43-
total_mi = it_and.mutual_info([0,0,1])
44-
redundant_info = it_and.redundant_info([1,2,0])
45-
unique_1 = it_and.unique_info([1,2,0])
46-
unique_2 = it_and.unique_info([2,1,0])
47-
synergy = it_and.synergy([1,2,0])
76+
total_mi = it_3and.mutual_info([1, 1, 1, 0])
77+
mi_12 = it_3and.mutual_info([1, 1, -1, 0])
78+
mi_13 = it_3and.mutual_info([1, -1, 1, 0])
79+
mi_23 = it_3and.mutual_info([-1, 1, 1, 0])
80+
mi_1 = it_3and.mutual_info([1, -1, -1, 0])
81+
mi_2 = it_3and.mutual_info([-1, 1, -1, 0])
82+
mi_3 = it_3and.mutual_info([-1, -1, 1, 0])
83+
redundant_info = it_3and.redundant_info([1, 2, 3, 0])
84+
redundant_12 = it_3and.redundant_info([1, 2, -1, 0])
85+
redundant_13 = it_3and.redundant_info([1, -1, 2, 0])
86+
redundant_23 = it_3and.redundant_info([-1, 1, 2, 0])
87+
unique_1 = it_3and.unique_info([1, 2, 3, 0])
88+
unique_2 = it_3and.unique_info([2, 1, 3, 0])
89+
unique_3 = it_3and.unique_info([2, 3, 1, 0])
90+
synergy = it_3and.synergy([1, 2, 3, 0])
91+
synergy_12 = it_3and.synergy([1, 2, -1, 0])
92+
synergy_13 = it_3and.synergy([1, -1, 2, 0])
93+
synergy_23 = it_3and.synergy([-1, 1, 2, 0])
94+
4895
print("total_mi = {}".format(total_mi))
96+
print("mi_12 = {}".format(mi_12))
97+
print("mi_13 = {}".format(mi_13))
98+
print("mi_23 = {}".format(mi_23))
99+
print("mi_1 = {}".format(mi_1))
100+
print("mi_2 = {}".format(mi_2))
101+
print("mi_3 = {}".format(mi_3))
49102
print("redundant_info = {}".format(redundant_info))
103+
print("redundant_12 = {}".format(redundant_12))
104+
print("redundant_13 = {}".format(redundant_13))
105+
print("redundant_23 = {}".format(redundant_23))
50106
print("unique_1 = {}".format(unique_1))
51107
print("unique_2 = {}".format(unique_2))
52-
print("synergy = {}\n".format(synergy))
108+
print("unique_3 = {}".format(unique_3))
109+
print("synergy = {}".format(synergy))
110+
print("synergy_12 = {}".format(synergy_12))
111+
print("synergy_13 = {}".format(synergy_13))
112+
print("synergy_23 = {}\n".format(synergy_23))
113+
114+
# 4 var multivariate analyses
115+
print("3-input Even parity")
116+
data = [
117+
[0, 0, 0, 0],
118+
[0, 0, 1, 1],
119+
[0, 1, 0, 1],
120+
[0, 1, 1, 0],
121+
[1, 0, 0, 1],
122+
[1, 0, 1, 0],
123+
[1, 1, 0, 0],
124+
[1, 1, 1, 1],
125+
]
126+
# creating the object and adding data
127+
it_par = infotheory.InfoTools(4, 0)
128+
it_par.set_equal_interval_binning([2] * 4, [0] * 4, [1] * 4)
129+
it_par.add_data(data)
130+
131+
# PID-ing
132+
total_mi = it_par.mutual_info([1, 1, 1, 0])
133+
redundant_info = it_par.redundant_info([1, 2, 3, 0])
134+
redundant_12 = it_par.redundant_info([1, 2, -1, 0])
135+
redundant_13 = it_par.redundant_info([1, -1, 2, 0])
136+
redundant_23 = it_par.redundant_info([-1, 1, 2, 0])
137+
unique_1 = it_par.unique_info([1, 2, 3, 0])
138+
unique_2 = it_par.unique_info([2, 1, 3, 0])
139+
unique_3 = it_par.unique_info([2, 3, 1, 0])
140+
synergy = it_par.synergy([1, 2, 3, 0])
141+
synergy_12 = it_par.synergy([1, 2, -1, 0])
142+
synergy_13 = it_par.synergy([1, -1, 2, 0])
143+
synergy_23 = it_par.synergy([-1, 1, 2, 0])
144+
145+
print("total_mi = {}".format(total_mi))
146+
print("redundant_info = {}".format(redundant_info))
147+
print("redundant_12 = {}".format(redundant_12))
148+
print("redundant_13 = {}".format(redundant_13))
149+
print("redundant_23 = {}".format(redundant_23))
150+
print("unique_1 = {}".format(unique_1))
151+
print("unique_2 = {}".format(unique_2))
152+
print("unique_3 = {}".format(unique_3))
153+
print("synergy = {}".format(synergy))
154+
print("synergy_12 = {}".format(synergy_12))
155+
print("synergy_13 = {}".format(synergy_13))
156+
print("synergy_23 = {}".format(synergy_23))

benchmarks/pid_cont.png

34.2 KB
Loading

0 commit comments

Comments
 (0)