diff --git a/examples/classification/SequenceCountInterpretabilityDemo.py b/examples/classification/SequenceCountInterpretabilityDemo.py index 1eaed0f9..5251ab10 100644 --- a/examples/classification/SequenceCountInterpretabilityDemo.py +++ b/examples/classification/SequenceCountInterpretabilityDemo.py @@ -25,7 +25,7 @@ def main(args): experiment_results = metrics(args) X_train = np.zeros((args.examples, 1, args.sequence_length, 2), dtype=np.uint32) - Y_train = np.random.randint(0, 3, size=(args.examples), dtype=np.uint32) + Y_train = np.random.randint(0, 4, size=(args.examples), dtype=np.uint32) for i in range(args.examples): position_1 = np.random.randint(0, args.sequence_length-4) position_2 = position_1+1 @@ -67,7 +67,7 @@ def main(args): X_train[i,0,position_5,0] = 0 X_train[i,0,position_5,1] = 0 - else: + elif Y_train[i] == 2: X_train[i,0,position_1,0] = 1 X_train[i,0,position_1,1] = 0 @@ -82,13 +82,27 @@ def main(args): X_train[i,0,position_5,0] = 0 X_train[i,0,position_5,1] = 0 + else: + X_train[i,0,position_1,0] = 1 + X_train[i,0,position_1,1] = 0 + X_train[i,0,position_2,0] = 0 + X_train[i,0,position_2,1] = 0 + + X_train[i,0,position_3,0] = 0 + X_train[i,0,position_3,1] = 0 + + X_train[i,0,position_4,0] = 0 + X_train[i,0,position_4,1] = 0 + + X_train[i,0,position_5,0] = 0 + X_train[i,0,position_5,1] = 0 if np.random.rand() <= args.noise: - Y_train[i] = np.random.choice(np.setdiff1d([0,1,2], [Y_train[i]])) + Y_train[i] = np.random.choice(np.setdiff1d([0,1,2,3], [Y_train[i]])) X_test = np.zeros((args.examples//10, 1, args.sequence_length, 2), dtype=np.uint32) - Y_test = np.random.randint(0, 3, size=(args.examples//10), dtype=np.uint32) + Y_test = np.random.randint(0, 4, size=(args.examples//10), dtype=np.uint32) for i in range(args.examples//10): position_1 = np.random.randint(0, args.sequence_length-4) position_2 = position_1+1 @@ -128,6 +142,21 @@ def main(args): X_test[i,0,position_4,0] = 0 X_test[i,0,position_4,1] = 0 + X_test[i,0,position_5,0] = 0 + X_test[i,0,position_5,1] = 0 + elif Y_test[i] == 2: + X_test[i,0,position_1,0] = 1 + X_test[i,0,position_1,1] = 0 + + X_test[i,0,position_2,0] = 1 + X_test[i,0,position_2,1] = 0 + + X_test[i,0,position_3,0] = 0 + X_test[i,0,position_3,1] = 0 + + X_test[i,0,position_4,0] = 0 + X_test[i,0,position_4,1] = 0 + X_test[i,0,position_5,0] = 0 X_test[i,0,position_5,1] = 0 else: @@ -135,7 +164,7 @@ def main(args): X_test[i,0,position_1,1] = 0 X_test[i,0,position_2,0] = 0 - X_test[i,0,position_2,1] = 1 + X_test[i,0,position_2,1] = 0 X_test[i,0,position_3,0] = 0 X_test[i,0,position_3,1] = 0