diff --git a/examples/classification/RecurrentInterpretabilityDemo.py b/examples/classification/RecurrentInterpretabilityDemo.py index 072110d3..40194d36 100644 --- a/examples/classification/RecurrentInterpretabilityDemo.py +++ b/examples/classification/RecurrentInterpretabilityDemo.py @@ -27,13 +27,11 @@ def main(args): X_train = np.random.randint(0, 2, size=(10000, args.number_of_features), dtype=np.uint32) Y_train = np.logical_xor(X_train[:, 0], X_train[:, 1]).astype(dtype=np.uint32) Y_train = np.where(np.random.rand(10000) <= args.noise, 1 - Y_train, Y_train) # Adds noise - #X_train[:,2:] = 0 X_train = X_train.reshape(-1, 1, args.number_of_features) X_test = np.random.randint(0, 2, size=(10000, args.number_of_features), dtype=np.uint32) Y_test = np.logical_xor(X_test[:, 0], X_test[:, 1]).astype(dtype=np.uint32) - #X_test[:,2:] = 0 X_test = X_test.reshape(-1, 1, args.number_of_features) @@ -145,9 +143,9 @@ def main(args): def default_args(**kwargs): parser = argparse.ArgumentParser() parser.add_argument("--epochs", default=50, type=int) - parser.add_argument("--number-of-clauses", default=50, type=int) + parser.add_argument("--number-of-clauses", default=4, type=int) parser.add_argument("--platform", default='CPU', type=str) - parser.add_argument("--T", default=10, type=int) + parser.add_argument("--T", default=2, type=int) parser.add_argument("--s", default=2.5, type=float) parser.add_argument("--number-of-features", default=6, type=int) parser.add_argument("--noise", default=0.1, type=float, help="Noisy XOR")