Skip to content

Commit

Permalink
Preparing for recurrent TM
Browse files Browse the repository at this point in the history
  • Loading branch information
olegranmo committed May 9, 2024
1 parent 6f6a214 commit dec49d0
Show file tree
Hide file tree
Showing 3 changed files with 257 additions and 5 deletions.
147 changes: 147 additions & 0 deletions examples/classification/RecurrentInterpretabilityDemo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import argparse
import logging

from tmu.models.classification.vanilla_classifier import TMClassifier
import numpy as np

_LOGGER = logging.getLogger(__name__)

def metrics(args):
return dict(
accuracy=[],
class_0_precision_positive=[],
class_0_recall_positive=[],
class_0_recall_negative=[],
class_0_precision_negative=[],
class_1_precision_positive=[],
class_1_recall_positive=[],
class_1_recall_negative=[],
class_1_precision_negative=[],
literal_frequency=None,
args=vars(args)
)

def main(args):
experiment_results = metrics(args)

X_train = np.random.randint(0, 2, size=(5000, args.number_of_features), dtype=np.uint32)
Y_train = np.logical_and(X_train[:, 0], X_train[:, 1]).astype(dtype=np.uint32)
Y_train = np.where(np.random.rand(5000) <= args.noise, 1 - Y_train, Y_train) # Adds noise
X_train = X_train.reshape(-1, 1, args.number_of_features)

X_test = np.random.randint(0, 2, size=(5000, args.number_of_features), dtype=np.uint32)
Y_test = np.logical_and(X_test[:, 0], X_test[:, 1]).astype(dtype=np.uint32)
X_test = X_test.reshape(-1, 1, args.number_of_features)

tm = TMClassifier(args.number_of_clauses, args.T, args.s, patch_dim=(1, 1), weighted_clauses=True, platform=args.platform, boost_true_positive_feedback=0, recurrent=True, incremental=False)

for i in range(20):
tm.fit(X_train, Y_train)
accuracy = 100 * (tm.predict(X_test) == Y_test).mean()
experiment_results["accuracy"].append(accuracy)
print("Accuracy:", accuracy)

np.set_printoptions(threshold=np.inf, linewidth=200, precision=2, suppress=True)

print("\nClass 0 Positive Clauses:\n")
precision = tm.clause_precision(0, 0, X_test, Y_test)
recall = tm.clause_recall(0, 0, X_test, Y_test)
experiment_results["class_0_precision_positive"].append(list(np.asarray(precision)))
experiment_results["class_0_recall_positive"].append(list(np.asarray(recall)))

for j in range(args.number_of_clauses // 2):
print("Clause #%d W:%d P:%.2f R:%.2f " % (j, tm.get_weight(0, 0, j), precision[j], recall[j]), end=' ')
l = []
for k in range(tm.clause_banks[0].number_of_features * 2):
if tm.get_ta_action(j, k, the_class=0, polarity=0):
if k < tm.clause_banks[0].number_of_features:
l.append(" x%d(%d)" % (k, tm.get_ta_state(j, k, the_class=0, polarity=0)))
else:
l.append("¬x%d(%d)" % (k - tm.clause_banks[0].number_of_features, tm.get_ta_state(j, k, the_class=0, polarity=0)))
print(" ∧ ".join(l))

print("\nClass 0 Negative Clauses:\n")

precision = tm.clause_precision(0, 1, X_test, Y_test)
recall = tm.clause_recall(0, 1, X_test, Y_test)
experiment_results["class_0_precision_negative"].append(list(np.asarray(precision)))
experiment_results["class_0_recall_negative"].append(list(np.asarray(recall)))

for j in range(args.number_of_clauses // 2):
print("Clause #%d W:%d P:%.2f R:%.2f " % (j, tm.get_weight(0, 1, j), precision[j], recall[j]), end=' ')
l = []
for k in range(tm.clause_banks[0].number_of_features * 2):
if tm.get_ta_action(j, k, the_class=0, polarity=1):
if k < tm.clause_banks[0].number_of_features:
l.append(" x%d(%d)" % (k, tm.get_ta_state(j, k, the_class=0, polarity=1)))
else:
l.append("¬x%d(%d)" % (k - tm.clause_banks[0].number_of_features, tm.get_ta_state(j, k, the_class=0, polarity=1)))
print(" ∧ ".join(l))

print("\nClass 1 Positive Clauses:\n")

precision = tm.clause_precision(1, 0, X_test, Y_test)
recall = tm.clause_recall(1, 0, X_test, Y_test)
experiment_results["class_1_precision_positive"].append(list(np.asarray(precision)))
experiment_results["class_1_recall_positive"].append(list(np.asarray(recall)))

for j in range(args.number_of_clauses // 2):
print("Clause #%d W:%d P:%.2f R:%.2f " % (j, tm.get_weight(1, 0, j), precision[j], recall[j]), end=' ')
l = []
for k in range(tm.clause_banks[0].number_of_features * 2):
if tm.get_ta_action(j, k, the_class=1, polarity=0):
if k < tm.clause_banks[0].number_of_features:
l.append(" x%d(%d)" % (k, tm.get_ta_state(j, k, the_class=1, polarity=0)))
else:
l.append("¬x%d(%d)" % (k - tm.clause_banks[0].number_of_features, tm.get_ta_state(j, k, the_class=1, polarity=0)))
print(" ∧ ".join(l))

print("\nClass 1 Negative Clauses:\n")

precision = tm.clause_precision(1, 1, X_test, Y_test)
recall = tm.clause_recall(1, 1, X_test, Y_test)
experiment_results["class_1_precision_negative"].append(list(np.asarray(precision)))
experiment_results["class_1_recall_negative"].append(list(np.asarray(recall)))

for j in range(args.number_of_clauses // 2):
print("Clause #%d W:%d P:%.2f R:%.2f " % (j, tm.get_weight(1, 1, j), precision[j], recall[j]), end=' ')
l = []
for k in range(tm.clause_banks[0].number_of_features * 2):
if tm.get_ta_action(j, k, the_class=1, polarity=1):
if k < tm.clause_banks[0].number_of_features:
l.append(" x%d(%d)" % (k, tm.get_ta_state(j, k, the_class=1, polarity=1)))
else:
l.append("¬x%d(%d)" % (k - tm.clause_banks[0].number_of_features, tm.get_ta_state(j, k, the_class=1, polarity=1)))
print(" ∧ ".join(l))

print("\nClause Co-Occurence Matrix:\n")
print(tm.clause_co_occurrence(X_test, percentage=True).toarray())

print("\nLiteral Frequency:\n")
print(tm.literal_clause_frequency())
experiment_results["literal_frequency"] = tm.literal_clause_frequency().tolist()


print(tm.clause_banks[0].number_of_features)
return experiment_results


def default_args(**kwargs):
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", default=2, type=int)
parser.add_argument("--number-of-clauses", default=10, type=int)
parser.add_argument("--platform", default='CPU', type=str)
parser.add_argument("--T", default=100, type=int)
parser.add_argument("--s", default=1.0, type=float)
parser.add_argument("--number-of-features", default=2, type=int)
parser.add_argument("--noise", default=0.0, type=float, help="Noisy XOR")
args = parser.parse_args()
for key, value in kwargs.items():
if key in args.__dict__:
setattr(args, key, value)
return args


if __name__ == "__main__":
results = main(default_args())
_LOGGER.info(results)
13 changes: 13 additions & 0 deletions tmu/lib/include/ClauseBank.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ void cb_type_i_feedback(
unsigned int *Xi
);

void cb_type_ii_feedback_recurrent(
unsigned int *ta_state,
unsigned int *output_one_patches,
int number_of_clauses,
int number_of_literals,
int number_of_state_bits,
int number_of_patches,
float update_p,
unsigned int *clause_active,
unsigned int *literal_active,
unsigned int *Xi
);

void cb_type_ii_feedback(
unsigned int *ta_state,
unsigned int *output_one_patches,
Expand Down
102 changes: 97 additions & 5 deletions tmu/lib/src/ClauseBank.c
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ static inline void cb_calculate_clause_output_feedback(unsigned int *ta_state, u
}
}


/* Calculate the output of each clause using the actions of each Tsetline Automaton. */
static inline int cb_calculate_clause_output_single_false_literal(unsigned int *ta_state, unsigned int *candidate_offending_literals, int number_of_ta_chunks, int number_of_state_bits, unsigned int filter, int number_of_patches, unsigned int *literal_active, unsigned int *Xi)
{
Expand Down Expand Up @@ -338,6 +339,97 @@ void cb_type_i_feedback(
}
}

void cb_type_ii_feeedback_clause_recurrent(
int number_of_clauses,
int clause,
int patch,
unsigned int *ta_state,
int number_of_ta_chunks,
int number_of_state_bits,
unsigned int *clause_active,
unsigned int *literal_active,
unsigned int filter,
unsigned int *Xi
)
{
unsigned int clause_pos = clause*number_of_ta_chunks*number_of_state_bits;

if (
cb_calculate_clause_output(
&ta_state[clause_pos],
number_of_ta_chunks,
number_of_state_bits,
filter,
literal_active,
&Xi[patch*number_of_ta_chunks]
)
)
{
if (clause_active[clause]) {
// Update clause with Type II Feedback
for (int k = 0; k < number_of_ta_chunks; ++k) {
unsigned int ta_pos = k*number_of_state_bits;
cb_inc(&ta_state[clause_pos + ta_pos], literal_active[k] & (~Xi[patch*number_of_ta_chunks + k]), number_of_state_bits);
}
}

if (patch > 1) {
// Proceed with included clauses from previous patch
for (int j = 0; j < number_of_clauses; ++j) {
unsigned int chunk_nr = j / 32;
unsigned int chunk_pos = j % 32;

if (ta_state[clause_pos + chunk_nr*number_of_state_bits + number_of_state_bits-1] & (1U << chunk_pos)) {
cb_type_ii_feeedback_clause_recurrent(number_of_clauses, j, patch-1, ta_state, number_of_ta_chunks, number_of_state_bits, clause_active, literal_active, filter, Xi);
}
}
}
}
}

void cb_type_ii_feedback_recurrent(
unsigned int *ta_state,
unsigned int *output_one_patches,
int number_of_clauses,
int number_of_literals,
int number_of_state_bits,
int number_of_patches,
float update_p,
unsigned int *clause_active,
unsigned int *literal_active,
unsigned int *Xi
)
{
unsigned int filter;
if (((number_of_literals) % 32) != 0) {
filter = (~(0xffffffff << ((number_of_literals) % 32)));
} else {
filter = 0xffffffff;
}
unsigned int number_of_ta_chunks = (number_of_literals-1)/32 + 1;

for (int j = 0; j < number_of_clauses; j++) {
if ((((float)fast_rand())/((float)FAST_RAND_MAX) > update_p) || (!clause_active[j])) {
continue;
}

unsigned int clause_pos = j*number_of_ta_chunks*number_of_state_bits;

cb_type_ii_feeedback_clause_recurrent(
number_of_clauses,
j,
number_of_patches-1,
ta_state,
number_of_ta_chunks,
number_of_state_bits,
clause_active,
literal_active,
filter,
Xi
);
}
}

void cb_type_ii_feedback(
unsigned int *ta_state,
unsigned int *output_one_patches,
Expand Down Expand Up @@ -741,14 +833,13 @@ void cb_calculate_clause_outputs_predict_recurrent(
&ta_state[clause_pos],
number_of_ta_chunks,
number_of_state_bits,
filter
)) && cb_calculate_clause_output_without_literal_active(
filter)
) && cb_calculate_clause_output_without_literal_active(
&ta_state[clause_pos],
number_of_ta_chunks,
number_of_state_bits,
filter,
&Xi[(number_of_patches-1)*number_of_ta_chunks]
);
&Xi[(number_of_patches-1)*number_of_ta_chunks]);
}
}

Expand Down Expand Up @@ -859,7 +950,8 @@ void cb_calculate_clause_features(
chunk_nr = (j + number_of_literals / 2) / 32;
chunk_pos = (j + number_of_literals / 2) % 32;

Xi[(patch + 1)*number_of_ta_chunks + chunk_nr] |= (1U << chunk_pos);
//Xi[(patch + 1)*number_of_ta_chunks + chunk_nr] |= (1U << chunk_pos);
Xi[(patch + 1)*number_of_ta_chunks + chunk_nr] &= ~(1U << chunk_pos);
}
}
}
Expand Down

0 comments on commit dec49d0

Please sign in to comment.