From 89bd5dace1cfca4a7e593afd3029e5a0b2d8e8bb Mon Sep 17 00:00:00 2001 From: Mayur1009 Date: Fri, 22 Nov 2024 11:14:15 +0100 Subject: [PATCH 01/12] Merge Mayur1009:main --- .gitignore | 4 + .../classification/FashionMNISTMultioutput.py | 130 ++++++ .../classification/MNISTMultioutput.py | 83 ++++ .../classification/MNIST_mod_color.py | 164 +++++++ pyproject.toml | 5 +- tmu/__init__.py | 2 +- tmu/clause_bank/base_clause_bank.py | 6 + tmu/clause_bank/clause_bank.py | 5 +- tmu/clause_bank/clause_bank_cuda.py | 22 +- .../calculate_clause_outputs_patchwise.cu | 67 +++ .../models/multioutput_classifier.py | 426 ++++++++++++++++++ tmu/logging_example.json | 6 - tmu/models/base.py | 3 +- .../classification/vanilla_classifier.py | 5 +- 14 files changed, 908 insertions(+), 20 deletions(-) create mode 100644 examples/experimental/classification/FashionMNISTMultioutput.py create mode 100644 examples/experimental/classification/MNISTMultioutput.py create mode 100644 examples/experimental/classification/MNIST_mod_color.py create mode 100644 tmu/clause_bank/cuda/calculate_clause_outputs_patchwise.cu create mode 100644 tmu/experimental/models/multioutput_classifier.py diff --git a/.gitignore b/.gitignore index 07817eca..25fd7db2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +tmu/tmulib.c # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -25,6 +26,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +typings/ # PyInstaller # Usually these files are written by a python script from a template @@ -136,3 +138,5 @@ wheelhouse/ cmake-build-debug cmake-build-release + +.envrc diff --git a/examples/experimental/classification/FashionMNISTMultioutput.py b/examples/experimental/classification/FashionMNISTMultioutput.py new file mode 100644 index 00000000..b36c27bd --- /dev/null +++ b/examples/experimental/classification/FashionMNISTMultioutput.py @@ -0,0 +1,130 @@ +import argparse +from pprint import pprint + +import numpy as np +from sklearn.datasets import fetch_openml +from sklearn.metrics import accuracy_score, classification_report, f1_score, hamming_loss, precision_score, recall_score +from sklearn.model_selection import train_test_split +from tmu.experimental.models.multioutput_classifier import TMCoalesceMultiOuputClassifier + + +def dataset_fmnist_ch(ch=8): + X, y = fetch_openml( + "Fashion-MNIST", + version=1, + return_X_y=True, + as_frame=False, + ) + + xtrain_orig, xtest_orig, ytrain_orig, ytest_orig = train_test_split(X, y, random_state=0, test_size=10000) + ytrain_orig = np.array(ytrain_orig, dtype=int) + ytest_orig = np.array(ytest_orig, dtype=int) + + xtrain = np.array(xtrain_orig).reshape(-1, 28, 28) + xtest = np.array(xtest_orig).reshape(-1, 28, 28) + + out = np.zeros((*xtrain.shape, ch)) + for j in range(ch): + t1 = (j + 1) * 255 / (ch + 1) + t2 = (j + 2) * 255 / (ch + 1) + out[:, :, :, j] = np.logical_and(xtrain >= t1, xtrain < t2) & 1 + xtrain = np.array(out) + + out = np.zeros((*xtest.shape, ch)) + for j in range(ch): + t1 = (j + 1) * 255 / (ch + 1) + t2 = (j + 2) * 255 / (ch + 1) + out[:, :, :, j] = np.logical_and(xtest >= t1, xtest < t2) & 1 + xtest = np.array(out) + + ytrain = np.zeros((ytrain_orig.shape[0], 10), dtype=int) + for i in range(ytrain_orig.shape[0]): + ytrain[i, ytrain_orig[i]] = 1 + ytest = np.zeros((ytest_orig.shape[0], 10), dtype=int) + for i in range(ytest_orig.shape[0]): + ytest[i, ytest_orig[i]] = 1 + + label_names = [ + "tshirt", + "trouser", + "pullover", + "dress", + "coat", + "sandal", + "shirt", + "sneaker", + "bag", + "ankleboot", + ] + original = (xtrain_orig, ytrain_orig, xtest_orig, ytest_orig) + return original, xtrain, ytrain, xtest, ytest, label_names + + +def arr_div(a, b): + return np.divide(a, b, out=np.zeros_like(a, dtype=np.float32), where=b != 0) + + +def metrics(true, pred): + land = np.logical_and(true, pred) + lor = np.logical_or(true, pred) + lxor = np.logical_xor(true, pred) # symmetric diff + n_correct_labels = np.sum(land, axis=1) + total_active_labels = np.sum(lor, axis=1) + n_miss_preds = np.sum(lxor, axis=1) + + acc = np.mean(arr_div(n_correct_labels, total_active_labels)) + pre = np.mean(arr_div(n_correct_labels, np.sum(pred, axis=1))) + rec = np.mean(arr_div(n_correct_labels, np.sum(true, axis=1))) + f1s = 2 * pre * rec / (pre + rec) + hml = np.sum(n_miss_preds) / (true.shape[0] * true.shape[1]) + + return { + "Hamming loss": hml, + "Accuracy": acc, + "Precision": pre, + "Recall": rec, + "F1 score": f1s, + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--clauses", default=2000, type=int) + parser.add_argument("--T", default=3125, type=int) + parser.add_argument("--s", default=10.0, type=float) + parser.add_argument("--q", default=-1, type=float) + parser.add_argument("--type_ratio", default=1.0, type=float) + parser.add_argument("--platform", default="GPU", type=str) + parser.add_argument("--epochs", default=1, type=int) + parser.add_argument("--patch", default=10, type=int) + args = parser.parse_args() + + params = dict( + number_of_clauses=args.clauses, + T=args.T, + s=args.s, + q=args.q, + type_i_ii_ratio=args.type_ratio, + patch_dim=(args.patch, args.patch), + platform=args.platform, + seed=10, + ) + + original, xtrain, ytrain, xtest, ytest, label_names = dataset_fmnist_ch() + + tm = TMCoalesceMultiOuputClassifier(**params) + + print("Training with params: ") + pprint(params) + + for epoch in range(args.epochs): + print(f"Epoch {epoch}/{args.epochs}") + tm.fit(xtrain, ytrain, progress_bar=True) + pred = tm.predict(xtest, progress_bar=True) + + met = metrics(ytest, pred) + rep = classification_report(ytest, pred, target_names=label_names) + + pprint(met) + print(rep) + print("------------------------------") diff --git a/examples/experimental/classification/MNISTMultioutput.py b/examples/experimental/classification/MNISTMultioutput.py new file mode 100644 index 00000000..c84c0c87 --- /dev/null +++ b/examples/experimental/classification/MNISTMultioutput.py @@ -0,0 +1,83 @@ +import argparse +from pprint import pprint + +import numpy as np +from sklearn.metrics import accuracy_score, classification_report, f1_score, hamming_loss, precision_score, recall_score +from tmu.data import MNIST +from tmu.experimental.models.multioutput_classifier import TMCoalesceMultiOuputClassifier + + +def dataset_mnist(): + data = MNIST().get() + xtrain_orig, xtest_orig, ytrain_orig, ytest_orig = ( + data["x_train"], + data["x_test"], + data["y_train"], + data["y_test"], + ) + + xtrain = xtrain_orig.reshape(-1, 28, 28) + xtest = xtest_orig.reshape(-1, 28, 28) + + ytrain = np.zeros((ytrain_orig.shape[0], 10), dtype=int) + for i in range(ytrain_orig.shape[0]): + ytrain[i, ytrain_orig[i]] = 1 + ytest = np.zeros((ytest_orig.shape[0], 10), dtype=int) + for i in range(ytest_orig.shape[0]): + ytest[i, ytest_orig[i]] = 1 + + label_names = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] + original = (xtrain_orig, ytrain_orig, xtest_orig, ytest_orig) + return original, xtrain, ytrain, xtest, ytest, label_names + +def metrics(true, pred): + met = { + "Subset accuracy": accuracy_score(true, pred), + "Hamming loss": hamming_loss(true, pred), + "F1 score": f1_score(true, pred, average="weighted"), + "Precision": precision_score(true, pred, average="weighted"), + "Recall": recall_score(true, pred, average="weighted"), + } + return met + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--clauses", default=2000, type=int) + parser.add_argument("--T", default=3125, type=int) + parser.add_argument("--s", default=10.0, type=float) + parser.add_argument("--q", default=-1, type=float) + parser.add_argument("--type_ratio", default=1.0, type=float) + parser.add_argument("--platform", default="GPU", type=str) + parser.add_argument("--epochs", default=1, type=int) + parser.add_argument("--patch", default=10, type=int) + args = parser.parse_args() + + params = dict( + number_of_clauses=args.clauses, + T=args.T, + s=args.s, + q=args.q, + type_i_ii_ratio=args.type_ratio, + patch_dim=(args.patch, args.patch), + platform=args.platform, + seed=10, + ) + + original, xtrain, ytrain, xtest, ytest, label_names = dataset_mnist() + + tm = TMCoalesceMultiOuputClassifier(**params) + + print("Training with params: ") + pprint(params) + + for epoch in range(args.epochs): + print(f"Epoch {epoch}/{args.epochs}") + tm.fit(xtrain, ytrain, progress_bar=True) + pred = tm.predict(xtest, progress_bar=True) + + met = metrics(ytest, pred) + rep = classification_report(ytest, pred, target_names=label_names) + + pprint(met) + print(rep) + print("------------------------------") diff --git a/examples/experimental/classification/MNIST_mod_color.py b/examples/experimental/classification/MNIST_mod_color.py new file mode 100644 index 00000000..12312e44 --- /dev/null +++ b/examples/experimental/classification/MNIST_mod_color.py @@ -0,0 +1,164 @@ +import argparse +from math import sqrt +from pprint import pprint + +import numpy as np +from sklearn.feature_extraction.text import CountVectorizer +from sklearn.metrics import classification_report +from tmu.data import MNIST +from tmu.experimental.models.multioutput_classifier import TMCoalesceMultiOuputClassifier +from tqdm import tqdm + +colors = { + "red": [1, 0, 0], + "green": [0, 1, 0], + "blue": [0, 0, 1], + "yellow": [1, 1, 0], + "cyan": [0, 1, 1], + "magenta": [1, 0, 1], + "white": [1, 1, 1], +} +num = { + 1: "one", + 2: "two", + 3: "three", + 4: "four", + 5: "five", + 6: "six", + 7: "seven", + 8: "eight", + 9: "nine", + 0: "zero", +} + + +def check_prime(n): + if n > 1: + is_prime = True + for i in range(2, int(sqrt(n)) + 1): + if n % i == 0: + is_prime = False + break + return is_prime + else: + return False + + +def add_color_mnist(x, y): + n = x.shape[0] + x = x.reshape(-1, 28, 28) + nt = np.concatenate([np.ones(n // 2), np.zeros(n - (n // 2))]) + np.random.shuffle(nt) + + x_color = np.stack([x] * 3, axis=-1) + + tx = [] + for i in tqdm(range(n)): + color_name = np.random.choice(list(colors.keys())) + label_text = f"{num[y[i]] if nt[i] else y[i]}" + is_prime = "prime" if check_prime(y[i]) else "" + odd_even = "odd" if y[i] & 1 else "even" + sentences = [ + f"{color_name} {label_text} {odd_even} {is_prime}", + f"The {label_text} {color_name} {odd_even} {is_prime}", + f"{color_name} image {label_text} {odd_even} {is_prime}", + ] + tx.append(np.random.choice(sentences)) + color = colors[color_name] + x_color[i, x[i, :, :] == 1, 0] = color[0] + x_color[i, x[i, :, :] == 1, 1] = color[1] + x_color[i, x[i, :, :] == 1, 2] = color[2] + + vectorizer = CountVectorizer(token_pattern=r"(?u)\b\w+\b") + y_color = vectorizer.fit_transform(tx).toarray() + + d = vectorizer.get_feature_names_out() + + return x_color, y_color, tx, d + + +def dataset_mnist_color(): + data = MNIST().get() + xtrain_orig, xtest_orig, ytrain_orig, ytest_orig = ( + data["x_train"], + data["x_test"], + data["y_train"], + data["y_test"], + ) + + xtrain, ytrain, txtrain, d1 = add_color_mnist(xtrain_orig, ytrain_orig) + xtest, ytest, txtest, d2 = add_color_mnist(xtest_orig, ytest_orig) + assert (d1 == d2).all(), f"d1 and d2 are not same. \n{d1}\n{d2}" + + original = (xtrain_orig, xtest_orig, ytrain_orig, ytest_orig) + return original, xtrain, ytrain, txtrain, xtest, ytest, txtest, d1 + + +def arr_div(a, b): + return np.divide(a, b, out=np.zeros_like(a, dtype=np.float32), where=b != 0) + + +def metrics(true, pred): + land = np.logical_and(true, pred) + lor = np.logical_or(true, pred) + lxor = np.logical_xor(true, pred) # symmetric diff + n_correct_labels = np.sum(land, axis=1) + total_active_labels = np.sum(lor, axis=1) + n_miss_preds = np.sum(lxor, axis=1) + + acc = np.mean(arr_div(n_correct_labels, total_active_labels)) + pre = np.mean(arr_div(n_correct_labels, np.sum(pred, axis=1))) + rec = np.mean(arr_div(n_correct_labels, np.sum(true, axis=1))) + f1s = 2 * pre * rec / (pre + rec) + hml = np.sum(n_miss_preds) / (true.shape[0] * true.shape[1]) + + return { + "Hamming loss": hml, + "Accuracy": acc, + "Precision": pre, + "Recall": rec, + "F1 score": f1s, + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--clauses", default=2000, type=int) + parser.add_argument("--T", default=3125, type=int) + parser.add_argument("--s", default=10.0, type=float) + parser.add_argument("--q", default=-1, type=float) + parser.add_argument("--type_ratio", default=1.0, type=float) + parser.add_argument("--platform", default="GPU", type=str) + parser.add_argument("--epochs", default=1, type=int) + parser.add_argument("--patch", default=10, type=int) + args = parser.parse_args() + + params = dict( + number_of_clauses=args.clauses, + T=args.T, + s=args.s, + q=args.q, + type_i_ii_ratio=args.type_ratio, + patch_dim=(args.patch, args.patch), + platform=args.platform, + seed=10, + ) + + original, xtrain, ytrain, txtrain, xtest, ytest, txtest, label_names = dataset_mnist_color() + + tm = TMCoalesceMultiOuputClassifier(**params) + + print("Training with params: ") + pprint(params) + + for epoch in range(args.epochs): + print(f"Epoch {epoch}/{args.epochs}") + tm.fit(xtrain, ytrain, progress_bar=True) + pred = tm.predict(xtest, progress_bar=True) + + met = metrics(ytest, pred) + rep = classification_report(ytest, pred, target_names=label_names) + + pprint(met) + print(rep) + print("------------------------------") diff --git a/pyproject.toml b/pyproject.toml index c920d9b8..73ecf6b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,9 @@ tests = [ homepage = "https://github.com/cair/tmu/" repository = "https://github.com/cair/tmu/" +[tool.basedpyright] +reportMissingTypeStubs = "information" +typeCheckingMode = "standard" [tool.setuptools.package-dir] tmu = "tmu" @@ -83,8 +86,6 @@ headers = [ include_dir = "tmu/lib/include" - - [tool.cibuildwheel] build = "*" test-skip = "" diff --git a/tmu/__init__.py b/tmu/__init__.py index 63a4d4b0..65e51ac5 100644 --- a/tmu/__init__.py +++ b/tmu/__init__.py @@ -21,4 +21,4 @@ except ImportError as e: raise ImportError("Could not import cffi compiled libraries. To fix this problem, run pip install -e .", e) -__version__ = "0.8.3" +__version__ = "0.8.3.58" diff --git a/tmu/clause_bank/base_clause_bank.py b/tmu/clause_bank/base_clause_bank.py index f980234e..42ad8e61 100644 --- a/tmu/clause_bank/base_clause_bank.py +++ b/tmu/clause_bank/base_clause_bank.py @@ -91,3 +91,9 @@ def get_ta_action(self, clause, ta): pos = int( clause * self.number_of_ta_chunks * self.number_of_state_bits_ta + ta_chunk * self.number_of_state_bits_ta + self.number_of_state_bits_ta - 1) return (self.clause_bank[pos] & (1 << chunk_pos)) > 0 + + def get_literals(self): + nc, nl = self.number_of_clauses, self.number_of_literals + results = np.array([[self.get_ta_action(i, j) for j in range(nl)] for i in range(nc)]) + return results.astype(np.int8) + diff --git a/tmu/clause_bank/clause_bank.py b/tmu/clause_bank/clause_bank.py index b3c014be..fb88bf3c 100644 --- a/tmu/clause_bank/clause_bank.py +++ b/tmu/clause_bank/clause_bank.py @@ -133,7 +133,8 @@ def initialize_clauses(self): order="c" ) - self.clause_bank[:, :, 0: self.number_of_state_bits_ta - 1] = np.uint32(~0) + # TODO: MAKE issue for newer numpy version for np.uint32(-1) + self.clause_bank[:, :, 0: self.number_of_state_bits_ta - 1] = np.array(~0).astype(np.uint32) self.clause_bank[:, :, self.number_of_state_bits_ta - 1] = 0 self.clause_bank = np.ascontiguousarray(self.clause_bank.reshape( (self.number_of_clauses * self.number_of_ta_chunks * self.number_of_state_bits_ta))) @@ -142,7 +143,7 @@ def initialize_clauses(self): self.clause_bank_ind = np.empty( (self.number_of_clauses, self.number_of_ta_chunks, self.number_of_state_bits_ind), dtype=np.uint32) - self.clause_bank_ind[:, :, :] = np.uint32(~0) + self.clause_bank_ind[:, :, :] = np.array(~0).astype(np.uint32) self.clause_bank_ind = np.ascontiguousarray(self.clause_bank_ind.reshape( (self.number_of_clauses * self.number_of_ta_chunks * self.number_of_state_bits_ind))) diff --git a/tmu/clause_bank/clause_bank_cuda.py b/tmu/clause_bank/clause_bank_cuda.py index ac59093c..c9953011 100644 --- a/tmu/clause_bank/clause_bank_cuda.py +++ b/tmu/clause_bank/clause_bank_cuda.py @@ -119,6 +119,10 @@ def __init__( self.calculate_clause_outputs_update_gpu = mod.get_function("calculate_clause_outputs_update") self.calculate_clause_outputs_update_gpu.prepare("PiiiPPPi") + mod = load_cuda_kernel(parameters, "cuda/calculate_clause_outputs_patchwise.cu") + self.calculate_clause_outputs_patchwise_gpu = mod.get_function("calculate_clause_outputs_patchwise") + self.calculate_clause_outputs_patchwise_gpu.prepare("PiiiPPi") + mod = load_cuda_kernel(parameters, "cuda/clause_feedback.cu") self.type_i_feedback_gpu = mod.get_function("type_i_feedback") self.type_i_feedback_gpu.prepare("PPiiiffiiPPPi") @@ -147,6 +151,7 @@ def __init__( dtype=np.uint32, order="c" ) + self.clause_output_patchwise_gpu = self._profiler.profile(cuda.mem_alloc, self.clause_output_patchwise.nbytes) self.clause_active_gpu = self._profiler.profile(cuda.mem_alloc, self.clause_output.nbytes) self.literal_active_gpu = self._profiler.profile(cuda.mem_alloc, self.number_of_ta_chunks * 4) @@ -219,18 +224,23 @@ def calculate_clause_outputs_update(self, literal_active, encoded_X, e): return self.clause_output def calculate_clause_outputs_patchwise(self, encoded_X, e): - xi_p = ffi.cast("unsigned int *", Xi.ctypes.data) - lib.cb_calculate_clause_outputs_patchwise( - self.cb_p, + + self.calculate_clause_outputs_patchwise_gpu.prepared_call( + self.grid, + self.block, + self.clause_bank_gpu, self.number_of_clauses, self.number_of_literals, self.number_of_state_bits_ta, - self.number_of_patches, - self.cop_p, - xi_p + self.clause_output_patchwise_gpu, + encoded_X, + np.int32(e) ) + self.cuda_ctx.synchronize() + self._profiler.profile(cuda.memcpy_dtoh, self.clause_output_patchwise, self.clause_output_patchwise_gpu) return self.clause_output_patchwise + def type_i_feedback( self, update_p, diff --git a/tmu/clause_bank/cuda/calculate_clause_outputs_patchwise.cu b/tmu/clause_bank/cuda/calculate_clause_outputs_patchwise.cu new file mode 100644 index 00000000..d3dd2ca9 --- /dev/null +++ b/tmu/clause_bank/cuda/calculate_clause_outputs_patchwise.cu @@ -0,0 +1,67 @@ +/*** +# Copyright (c) 2021 Ole-Christoffer Granmo + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# This code implements the Convolutional Tsetlin Machine from paper arXiv:1905.09688 +# https://arxiv.org/abs/1905.09688 +***/ +#include + +extern "C" +{ + __device__ inline void calculate_clause_output_patchwise(unsigned int *ta_state, int number_of_ta_chunks, int number_of_state_bits, unsigned int filter, unsigned int *output, unsigned int *Xi) + { + for (int patch = 0; patch < NUMBER_OF_PATCHES; ++patch) { + output[patch] = 1; + for (int k = 0; k < number_of_ta_chunks-1; k++) { + unsigned int pos = k*number_of_state_bits + number_of_state_bits-1; + output[patch] = output[patch] && (ta_state[pos] & Xi[patch*number_of_ta_chunks + k]) == ta_state[pos]; + + if (!output[patch]) { + break; + } + } + + unsigned int pos = (number_of_ta_chunks-1)*number_of_state_bits + number_of_state_bits-1; + output[patch] = output[patch] && + (ta_state[pos] & Xi[patch*number_of_ta_chunks + number_of_ta_chunks - 1] & filter) == + (ta_state[pos] & filter); + } + } + + __global__ void calculate_clause_outputs_patchwise(unsigned int *ta_state, int number_of_clauses, int number_of_literals, int number_of_state_bits, unsigned int *clause_output, unsigned int *X, int e) + { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + unsigned int filter; + if (((number_of_literals) % 32) != 0) { + filter = (~(0xffffffff << ((number_of_literals) % 32))); + } else { + filter = 0xffffffff; + } + unsigned int number_of_ta_chunks = (number_of_literals-1)/32 + 1; + + for (int j = index; j < number_of_clauses; j += stride) { + unsigned int clause_pos = j*number_of_ta_chunks*number_of_state_bits; + calculate_clause_output_patchwise(&ta_state[clause_pos], number_of_ta_chunks, number_of_state_bits, filter, &clause_output[j*NUMBER_OF_PATCHES], &X[e*(number_of_ta_chunks*NUMBER_OF_PATCHES)]); + } + } +} diff --git a/tmu/experimental/models/multioutput_classifier.py b/tmu/experimental/models/multioutput_classifier.py new file mode 100644 index 00000000..20f9161f --- /dev/null +++ b/tmu/experimental/models/multioutput_classifier.py @@ -0,0 +1,426 @@ +# Copyright (c) 2023 Ole-Christoffer Granmo +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import numpy as np + +# import pandas as pd +from tqdm import tqdm + +from tmu.models.base import MultiWeightBankMixin, SingleClauseBankMixin, TMBaseModel +from tmu.util.encoded_data_cache import DataEncoderCache +from tmu.weight_bank import WeightBank + + +class TMCoalesceMultiOuputClassifier( + TMBaseModel, SingleClauseBankMixin, MultiWeightBankMixin +): + def __init__( + self, + number_of_clauses, + T, + s, + type_i_ii_ratio=1.0, + type_iii_feedback=False, + focused_negative_sampling=False, + output_balancing=False, + d=200.0, + platform="CPU", + patch_dim=None, + feature_negation=True, + boost_true_positive_feedback=1, + reuse_random_feedback=0, + max_positive_clauses=None, + max_included_literals=None, + number_of_state_bits_ta=8, + number_of_state_bits_ind=8, + weighted_clauses=False, + clause_drop_p=0.0, + literal_drop_p=0.0, + q=-1, + seed=None, + ): + super().__init__( + number_of_clauses=number_of_clauses, + T=T, + s=s, + type_i_ii_ratio=type_i_ii_ratio, + type_iii_feedback=type_iii_feedback, + focused_negative_sampling=focused_negative_sampling, + output_balancing=output_balancing, + d=d, + platform=platform, + patch_dim=patch_dim, + feature_negation=feature_negation, + boost_true_positive_feedback=boost_true_positive_feedback, + reuse_random_feedback=reuse_random_feedback, + max_included_literals=max_included_literals, + number_of_state_bits_ta=number_of_state_bits_ta, + number_of_state_bits_ind=number_of_state_bits_ind, + weighted_clauses=weighted_clauses, + clause_drop_p=clause_drop_p, + literal_drop_p=literal_drop_p, + seed=seed, + ) + SingleClauseBankMixin.__init__(self) + MultiWeightBankMixin.__init__(self, seed=seed) + + # These data structures cache the encoded data for the training and test sets. It also makes a fast-check if + # training data has changed, and only re-encodes if it has. + self.test_encoder_cache = DataEncoderCache(seed=self.seed) + self.train_encoder_cache = DataEncoderCache(seed=self.seed) + + self.max_positive_clauses = max_positive_clauses + self.q = q + + def init_clause_bank(self, X: np.ndarray, Y: np.ndarray): + clause_bank_type, clause_bank_args = self.build_clause_bank(X=X) + self.clause_bank = clause_bank_type(**clause_bank_args) + self.X_shape = X.shape + + def init_weight_bank(self, X: np.ndarray, Y: np.ndarray): + self.number_of_classes = Y.shape[1] + if self.q < 0: + self.q = max(1, self.number_of_classes - 1) / 2 + self.weight_banks.set_clause_init( + WeightBank, + dict( + weights=self.rng.choice([-1, 1], size=self.number_of_clauses).astype( + np.int32 + ) + ), + ) + self.weight_banks.populate(list(range(self.number_of_classes))) + + def init_after(self, X: np.ndarray, Y: np.ndarray): + if self.max_included_literals is None: + self.max_included_literals = self.clause_bank.number_of_literals + + if self.max_positive_clauses is None: + self.max_positive_clauses = self.number_of_clauses + + def fit(self, X, Y, shuffle=True, progress_bar=False, met=False, **kwargs): + self.init(X, Y) + + encoded_X_train = self.train_encoder_cache.get_encoded_data( + X, encoder_func=lambda x: self.clause_bank.prepare_X(X) + ) + + # Drops clauses randomly based on clause drop probability + self.clause_active = ( + self.rng.rand(self.number_of_clauses) >= self.clause_drop_p + ).astype(np.int32) + + # Literals are dropped based on literal drop probability + self.literal_active = np.zeros( + self.clause_bank.number_of_ta_chunks, dtype=np.uint32 + ) + literal_active_integer = ( + self.rng.rand(self.clause_bank.number_of_literals) >= self.literal_drop_p + ) + for k in range(self.clause_bank.number_of_literals): + if literal_active_integer[k] == 1: + ta_chunk = k // 32 + chunk_pos = k % 32 + self.literal_active[ta_chunk] |= 1 << chunk_pos + + if not self.feature_negation: + for k in range( + self.clause_bank.number_of_literals // 2, + self.clause_bank.number_of_literals, + ): + ta_chunk = k // 32 + chunk_pos = k % 32 + self.literal_active[ta_chunk] &= ~(1 << chunk_pos) + + self.literal_active = self.literal_active.astype(np.uint32) + + shuffled_index = np.arange(X.shape[0]) + if shuffle: + self.rng.shuffle(shuffled_index) + + pbar = tqdm(shuffled_index) if progress_bar else shuffled_index + + # Combine all weight banks, to make use of faster numpy matrix operation + self.wcomb = np.empty( + (self.number_of_clauses, self.number_of_classes), dtype=np.int32 + ) + for c in range(self.number_of_classes): + self.wcomb[:, c] = self.weight_banks[c].get_weights() + + # Find all 0 and 1 indices in Y + pos_class_ind = [np.where(i == 1)[0] for i in Y] + neg_class_ind = [np.where(i == 0)[0] for i in Y] + + self.pf = np.zeros(self.number_of_classes) + self.nf = np.zeros(self.number_of_classes) + self.class_sums_per_sample = np.empty((X.shape[0], self.number_of_classes)) + self.update_p_per_sample = np.empty((X.shape[0], self.number_of_classes)) + # self.avg_n_neg_classes = 0 + + for e in pbar: + clause_outputs = self.clause_bank.calculate_clause_outputs_update( + self.literal_active, encoded_X_train, e + ) + class_sums = (clause_outputs * self.clause_active)[ + np.newaxis, : + ] @ self.wcomb + class_sums = np.clip(class_sums, -self.T, self.T).astype(np.int32).ravel() + self.class_sums_per_sample[e, :] = class_sums + + pos_ind = pos_class_ind[e] + neg_ind = neg_class_ind[e] + t = self.T * np.ones(self.number_of_classes) + t[neg_ind] *= -1 + self.update_ps = (t - class_sums) / (2 * t) + + self.update_p_per_sample[e, :] = self.update_ps + + for c in pos_ind: + update_p = self.update_ps[c] + self.clause_bank.type_i_feedback( + update_p=update_p * self.type_i_p, + clause_active=self.clause_active + * (self.weight_banks[c].get_weights() >= 0), + literal_active=self.literal_active, + encoded_X=encoded_X_train, + e=e, + ) + self.clause_bank.type_ii_feedback( + update_p=update_p * self.type_ii_p, + clause_active=self.clause_active + * (self.weight_banks[c].get_weights() < 0), + literal_active=self.literal_active, + encoded_X=encoded_X_train, + e=e, + ) + if ( + self.weight_banks[c].get_weights() >= 0 + ).sum() < self.max_positive_clauses: + self.weight_banks[c].increment( + clause_output=clause_outputs, + update_p=update_p, + clause_active=self.clause_active, + positive_weights=True, + ) + self.wcomb[:, c] = self.weight_banks[c].get_weights() + + self.update_ps[c] = 0.0 + self.pf[c] += 1 + + if np.sum(self.update_ps) == 0: + continue + + rand_smp = self.rng.random_sample(self.number_of_classes) + self.rng.shuffle(neg_ind) + + for c in neg_ind: + if rand_smp[c] <= (self.q / max(1, self.number_of_classes - 1)): + update_p = self.update_ps[c] + self.clause_bank.type_i_feedback( + update_p=update_p * self.type_i_p, + clause_active=self.clause_active + * (self.weight_banks[c].get_weights() < 0), + literal_active=self.literal_active, + encoded_X=encoded_X_train, + e=e, + ) + + self.clause_bank.type_ii_feedback( + update_p=update_p * self.type_ii_p, + clause_active=self.clause_active + * (self.weight_banks[c].get_weights() >= 0), + literal_active=self.literal_active, + encoded_X=encoded_X_train, + e=e, + ) + + self.weight_banks[c].decrement( + clause_output=clause_outputs, + update_p=update_p, + clause_active=self.clause_active, + negative_weights=True, + ) + self.wcomb[:, c] = self.weight_banks[c].get_weights() + self.update_ps[c] = 0.0 + self.nf[c] += 1 + # self.avg_n_neg_classes += 1 + # print( + # f"Average num of neg classes selected = {self.avg_n_neg_classes / Y.shape[0]}" + # ) + # print( + # pd.DataFrame( + # { + # "n_pos": self.pf, + # "n_neg": self.nf, + # "R": self.pf / self.nf, + # "n_lab": Y.sum(axis=0), + # "n_lab_e": Y.sum(axis=0) / Y.shape[0], + # "n_nlb": Y.shape[0] - Y.sum(axis=0), + # "n_nlb_e": (Y.shape[0] - Y.sum(axis=0)) / Y.shape[0], + # } + # ) + # ) + if met: + return { + "pf": self.pf, + "nf": self.nf, + "class_sums": self.class_sums_per_sample, + "update_p": self.update_p_per_sample, + } + + def predict( + self, + X, + shuffle=False, + clip_class_sum=False, + return_class_sums: bool = False, + progress_bar=False, + **kwargs, + ) -> tuple[np.ndarray, np.ndarray] | np.ndarray: + encoded_X_test = self.clause_bank.prepare_X(X) + + for c in range(self.number_of_classes): + self.wcomb[:, c] = self.weight_banks[c].get_weights() + + shuffled_index = np.arange(X.shape[0]) + if shuffle: + self.rng.shuffle(shuffled_index) + pbar = tqdm(shuffled_index) if progress_bar else shuffled_index + + # Compute class sums for all samples + class_sums = np.empty((X.shape[0], self.number_of_classes)) + for e in pbar: + class_sums[e, :] = self.compute_class_sums( + encoded_X_test, e, clip_class_sum + ) + + output = (class_sums >= 0).astype(np.uint32) + + if return_class_sums: + return output, class_sums + else: + return output + + def compute_class_sums(self, encoded_X_test, ith_sample: int, clip_class_sum: bool): + """The following function evaluates the resulting class sum votes. + + Args: + ith_sample (int): The index of the sample + clip_class_sum (bool): Wether to clip class sums + + Returns: + list[int]: list of all class sums + """ + clause_outputs = self.clause_bank.calculate_clause_outputs_predict( + encoded_X_test, ith_sample + ) + class_sums = clause_outputs[np.newaxis, :] @ self.wcomb + if clip_class_sum: + class_sums = np.clip(class_sums, -self.T, self.T).astype(np.int32) + return class_sums + + def to_cpu(self): + if self.platform in ["GPU", "CUDA"]: + arr = np.empty((self.X_shape)) + clause_bank_gpu = self.clause_bank + clause_bank_gpu.synchronize_clause_bank() + clause_bank_type, clause_bank_args = self._build_cpu_bank(arr) + clause_bank_cpu = clause_bank_type(**clause_bank_args) + + clause_bank_cpu.clause_bank = clause_bank_gpu.clause_bank + clause_bank_cpu.clause_output = clause_bank_gpu.clause_output + clause_bank_cpu.literal_clause_count = clause_bank_gpu.literal_clause_count + + clause_bank_cpu._cffi_init() + + self.clause_bank = clause_bank_cpu + self.platform = "CPU" + print("to_cpu(): Successful....") + + elif self.platform == "CPU": + print("to_cpu(): Already CPU....") + + else: + print("to_cpu(): Not implemented....") + + def clause_precision(self, the_class, X, Y): + clause_outputs = self.transform(X) + weights = self.weight_banks[the_class].get_weights() + + positive_clause_outputs = (weights >= 0)[ + :, np.newaxis + ].transpose() * clause_outputs + true_positive_clause_outputs = positive_clause_outputs[Y == the_class].sum( + axis=0 + ) + false_positive_clause_outputs = positive_clause_outputs[Y != the_class].sum( + axis=0 + ) + + positive_clause_outputs = (weights < 0)[ + :, np.newaxis + ].transpose() * clause_outputs + true_positive_clause_outputs += positive_clause_outputs[Y != the_class].sum( + axis=0 + ) + false_positive_clause_outputs += positive_clause_outputs[Y == the_class].sum( + axis=0 + ) + + return np.where( + true_positive_clause_outputs + false_positive_clause_outputs == 0, + 0, + 1.0 + * true_positive_clause_outputs + / (true_positive_clause_outputs + false_positive_clause_outputs), + ) + + def clause_recall(self, the_class, X, Y): + clause_outputs = self.transform(X) + weights = self.weight_banks[the_class].get_weights() + + positive_clause_outputs = (weights >= 0)[ + :, np.newaxis + ].transpose() * clause_outputs + true_positive_clause_outputs = ( + positive_clause_outputs[Y == the_class].sum(axis=0) + / Y[Y == the_class].shape[0] + ) + + positive_clause_outputs = (weights < 0)[ + :, np.newaxis + ].transpose() * clause_outputs + true_positive_clause_outputs += ( + positive_clause_outputs[Y != the_class].sum(axis=0) + / Y[Y != the_class].shape[0] + ) + + return true_positive_clause_outputs + + def get_weights(self, the_class): + return self.weight_banks[the_class].get_weights() + + def get_weight(self, the_class, clause): + return self.weight_banks[the_class].get_weights()[clause] + + def set_weight(self, the_class, clause, weight): + self.weight_banks[the_class].get_weights()[clause] = weight + + def number_of_include_actions(self, clause): + return self.clause_bank.number_of_include_actions(clause) diff --git a/tmu/logging_example.json b/tmu/logging_example.json index 6d7c4afc..44eab162 100644 --- a/tmu/logging_example.json +++ b/tmu/logging_example.json @@ -21,11 +21,5 @@ ], "propagate": false } - }, - "root": { - "level": "DEBUG", - "handlers": [ - "console" - ] } } diff --git a/tmu/models/base.py b/tmu/models/base.py index 798bbbae..84dd9448 100644 --- a/tmu/models/base.py +++ b/tmu/models/base.py @@ -189,7 +189,7 @@ def set_ta_state(self, clause, ta, state, **kwargs): def fit(self, X, Y, *args, **kwargs): raise NotImplementedError("fit(self, X, Y, *args, **kwargs) is not implemented for your model") - def predict(self, X, shuffle=True) -> np.ndarray: + def predict(self, X, shuffle=True): raise NotImplementedError("predict(self, X: np.ndarray") def init_clause_bank(self, X: np.ndarray, Y: np.ndarray): @@ -244,6 +244,7 @@ def _build_gpu_bank(self, X: np.ndarray): if not cuda_installed: _LOGGER.warning("CUDA not installed, using CPU clause bank") + self.platform = "CPU" return self._build_cpu_bank(X=X) clause_bank_type = ClauseBankCUDA diff --git a/tmu/models/classification/vanilla_classifier.py b/tmu/models/classification/vanilla_classifier.py index 5fc5fa4e..685432f7 100644 --- a/tmu/models/classification/vanilla_classifier.py +++ b/tmu/models/classification/vanilla_classifier.py @@ -24,6 +24,7 @@ from tmu.weight_bank import WeightBank import numpy as np import logging +from tqdm import tqdm _LOGGER = logging.getLogger(__name__) @@ -396,7 +397,7 @@ def fit( if shuffle: self.rng.shuffle(sample_indices) - for sample_idx in sample_indices: + for sample_idx in tqdm( sample_indices ): target: int = Y[sample_idx] not_target: int | None = self.weight_banks.sample(exclude=[target]) @@ -436,7 +437,7 @@ def predict( encoded_X_test=encoded_X_test, ith_sample=i, clip_class_sum=clip_class_sum - ) for i in range(X.shape[0]) + ) for i in tqdm( range(X.shape[0]) ) ]) max_classes = np.argmax(class_sums, axis=1) From 1c9fcb3b8af44f3e7ec0cf3761e76a5401c2d299 Mon Sep 17 00:00:00 2001 From: Mayur1009 Date: Fri, 22 Nov 2024 11:17:06 +0100 Subject: [PATCH 02/12] Changed .gitignore --- .gitignore | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitignore b/.gitignore index 25fd7db2..eaaafbbb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ -tmu/tmulib.c # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -26,7 +25,6 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST -typings/ # PyInstaller # Usually these files are written by a python script from a template From d0d06dafe7214ed3571d428cb4f32e7ec02fe78a Mon Sep 17 00:00:00 2001 From: Mayur1009 Date: Fri, 22 Nov 2024 11:19:41 +0100 Subject: [PATCH 03/12] Changed pyproject.toml --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 73ecf6b2..db154fe2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,6 @@ homepage = "https://github.com/cair/tmu/" repository = "https://github.com/cair/tmu/" [tool.basedpyright] -reportMissingTypeStubs = "information" typeCheckingMode = "standard" [tool.setuptools.package-dir] @@ -86,6 +85,8 @@ headers = [ include_dir = "tmu/lib/include" + + [tool.cibuildwheel] build = "*" test-skip = "" From fcfb565a375666ae31da3c5bdfea6f73a2540ba9 Mon Sep 17 00:00:00 2001 From: Mayur1009 Date: Fri, 22 Nov 2024 11:48:44 +0100 Subject: [PATCH 04/12] Changed __init__ --- tmu/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tmu/__init__.py b/tmu/__init__.py index 65e51ac5..63a4d4b0 100644 --- a/tmu/__init__.py +++ b/tmu/__init__.py @@ -21,4 +21,4 @@ except ImportError as e: raise ImportError("Could not import cffi compiled libraries. To fix this problem, run pip install -e .", e) -__version__ = "0.8.3.58" +__version__ = "0.8.3" From 3a57b2dc1432daa2725892b3b10869641c13adf0 Mon Sep 17 00:00:00 2001 From: Mayur1009 Date: Fri, 22 Nov 2024 11:58:06 +0100 Subject: [PATCH 05/12] Revert changes to vanilla_classifier --- tmu/models/classification/vanilla_classifier.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tmu/models/classification/vanilla_classifier.py b/tmu/models/classification/vanilla_classifier.py index 685432f7..5fc5fa4e 100644 --- a/tmu/models/classification/vanilla_classifier.py +++ b/tmu/models/classification/vanilla_classifier.py @@ -24,7 +24,6 @@ from tmu.weight_bank import WeightBank import numpy as np import logging -from tqdm import tqdm _LOGGER = logging.getLogger(__name__) @@ -397,7 +396,7 @@ def fit( if shuffle: self.rng.shuffle(sample_indices) - for sample_idx in tqdm( sample_indices ): + for sample_idx in sample_indices: target: int = Y[sample_idx] not_target: int | None = self.weight_banks.sample(exclude=[target]) @@ -437,7 +436,7 @@ def predict( encoded_X_test=encoded_X_test, ith_sample=i, clip_class_sum=clip_class_sum - ) for i in tqdm( range(X.shape[0]) ) + ) for i in range(X.shape[0]) ]) max_classes = np.argmax(class_sums, axis=1) From 047f0a8b0a7ee3796ac1d090d342261eb4e0125b Mon Sep 17 00:00:00 2001 From: Mayur1009 Date: Fri, 22 Nov 2024 12:21:36 +0100 Subject: [PATCH 06/12] Revert changes in logging_example --- tmu/logging_example.json | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tmu/logging_example.json b/tmu/logging_example.json index 44eab162..c3314cbb 100644 --- a/tmu/logging_example.json +++ b/tmu/logging_example.json @@ -20,6 +20,12 @@ "console" ], "propagate": false + }, + "root": { + "level": "DEBUG", + "handlers": [ + "console" + ] } } } From 25d937859983465cc6c3993d2988c11c792720b1 Mon Sep 17 00:00:00 2001 From: Mayur1009 Date: Fri, 22 Nov 2024 12:22:58 +0100 Subject: [PATCH 07/12] Changes in models/base.py --- tmu/models/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tmu/models/base.py b/tmu/models/base.py index 84dd9448..798bbbae 100644 --- a/tmu/models/base.py +++ b/tmu/models/base.py @@ -189,7 +189,7 @@ def set_ta_state(self, clause, ta, state, **kwargs): def fit(self, X, Y, *args, **kwargs): raise NotImplementedError("fit(self, X, Y, *args, **kwargs) is not implemented for your model") - def predict(self, X, shuffle=True): + def predict(self, X, shuffle=True) -> np.ndarray: raise NotImplementedError("predict(self, X: np.ndarray") def init_clause_bank(self, X: np.ndarray, Y: np.ndarray): @@ -244,7 +244,6 @@ def _build_gpu_bank(self, X: np.ndarray): if not cuda_installed: _LOGGER.warning("CUDA not installed, using CPU clause bank") - self.platform = "CPU" return self._build_cpu_bank(X=X) clause_bank_type = ClauseBankCUDA From 757d2e14c697e5953cb080cb97ba3af237ae091a Mon Sep 17 00:00:00 2001 From: Mayur1009 Date: Fri, 22 Nov 2024 12:24:16 +0100 Subject: [PATCH 08/12] Remove a multioutput example --- .../classification/MNIST_mod_color.py | 164 ------------------ 1 file changed, 164 deletions(-) delete mode 100644 examples/experimental/classification/MNIST_mod_color.py diff --git a/examples/experimental/classification/MNIST_mod_color.py b/examples/experimental/classification/MNIST_mod_color.py deleted file mode 100644 index 12312e44..00000000 --- a/examples/experimental/classification/MNIST_mod_color.py +++ /dev/null @@ -1,164 +0,0 @@ -import argparse -from math import sqrt -from pprint import pprint - -import numpy as np -from sklearn.feature_extraction.text import CountVectorizer -from sklearn.metrics import classification_report -from tmu.data import MNIST -from tmu.experimental.models.multioutput_classifier import TMCoalesceMultiOuputClassifier -from tqdm import tqdm - -colors = { - "red": [1, 0, 0], - "green": [0, 1, 0], - "blue": [0, 0, 1], - "yellow": [1, 1, 0], - "cyan": [0, 1, 1], - "magenta": [1, 0, 1], - "white": [1, 1, 1], -} -num = { - 1: "one", - 2: "two", - 3: "three", - 4: "four", - 5: "five", - 6: "six", - 7: "seven", - 8: "eight", - 9: "nine", - 0: "zero", -} - - -def check_prime(n): - if n > 1: - is_prime = True - for i in range(2, int(sqrt(n)) + 1): - if n % i == 0: - is_prime = False - break - return is_prime - else: - return False - - -def add_color_mnist(x, y): - n = x.shape[0] - x = x.reshape(-1, 28, 28) - nt = np.concatenate([np.ones(n // 2), np.zeros(n - (n // 2))]) - np.random.shuffle(nt) - - x_color = np.stack([x] * 3, axis=-1) - - tx = [] - for i in tqdm(range(n)): - color_name = np.random.choice(list(colors.keys())) - label_text = f"{num[y[i]] if nt[i] else y[i]}" - is_prime = "prime" if check_prime(y[i]) else "" - odd_even = "odd" if y[i] & 1 else "even" - sentences = [ - f"{color_name} {label_text} {odd_even} {is_prime}", - f"The {label_text} {color_name} {odd_even} {is_prime}", - f"{color_name} image {label_text} {odd_even} {is_prime}", - ] - tx.append(np.random.choice(sentences)) - color = colors[color_name] - x_color[i, x[i, :, :] == 1, 0] = color[0] - x_color[i, x[i, :, :] == 1, 1] = color[1] - x_color[i, x[i, :, :] == 1, 2] = color[2] - - vectorizer = CountVectorizer(token_pattern=r"(?u)\b\w+\b") - y_color = vectorizer.fit_transform(tx).toarray() - - d = vectorizer.get_feature_names_out() - - return x_color, y_color, tx, d - - -def dataset_mnist_color(): - data = MNIST().get() - xtrain_orig, xtest_orig, ytrain_orig, ytest_orig = ( - data["x_train"], - data["x_test"], - data["y_train"], - data["y_test"], - ) - - xtrain, ytrain, txtrain, d1 = add_color_mnist(xtrain_orig, ytrain_orig) - xtest, ytest, txtest, d2 = add_color_mnist(xtest_orig, ytest_orig) - assert (d1 == d2).all(), f"d1 and d2 are not same. \n{d1}\n{d2}" - - original = (xtrain_orig, xtest_orig, ytrain_orig, ytest_orig) - return original, xtrain, ytrain, txtrain, xtest, ytest, txtest, d1 - - -def arr_div(a, b): - return np.divide(a, b, out=np.zeros_like(a, dtype=np.float32), where=b != 0) - - -def metrics(true, pred): - land = np.logical_and(true, pred) - lor = np.logical_or(true, pred) - lxor = np.logical_xor(true, pred) # symmetric diff - n_correct_labels = np.sum(land, axis=1) - total_active_labels = np.sum(lor, axis=1) - n_miss_preds = np.sum(lxor, axis=1) - - acc = np.mean(arr_div(n_correct_labels, total_active_labels)) - pre = np.mean(arr_div(n_correct_labels, np.sum(pred, axis=1))) - rec = np.mean(arr_div(n_correct_labels, np.sum(true, axis=1))) - f1s = 2 * pre * rec / (pre + rec) - hml = np.sum(n_miss_preds) / (true.shape[0] * true.shape[1]) - - return { - "Hamming loss": hml, - "Accuracy": acc, - "Precision": pre, - "Recall": rec, - "F1 score": f1s, - } - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--clauses", default=2000, type=int) - parser.add_argument("--T", default=3125, type=int) - parser.add_argument("--s", default=10.0, type=float) - parser.add_argument("--q", default=-1, type=float) - parser.add_argument("--type_ratio", default=1.0, type=float) - parser.add_argument("--platform", default="GPU", type=str) - parser.add_argument("--epochs", default=1, type=int) - parser.add_argument("--patch", default=10, type=int) - args = parser.parse_args() - - params = dict( - number_of_clauses=args.clauses, - T=args.T, - s=args.s, - q=args.q, - type_i_ii_ratio=args.type_ratio, - patch_dim=(args.patch, args.patch), - platform=args.platform, - seed=10, - ) - - original, xtrain, ytrain, txtrain, xtest, ytest, txtest, label_names = dataset_mnist_color() - - tm = TMCoalesceMultiOuputClassifier(**params) - - print("Training with params: ") - pprint(params) - - for epoch in range(args.epochs): - print(f"Epoch {epoch}/{args.epochs}") - tm.fit(xtrain, ytrain, progress_bar=True) - pred = tm.predict(xtest, progress_bar=True) - - met = metrics(ytest, pred) - rep = classification_report(ytest, pred, target_names=label_names) - - pprint(met) - print(rep) - print("------------------------------") From d77671383c6e41dca9e66e1ca578f5abe023dd46 Mon Sep 17 00:00:00 2001 From: Mayur1009 Date: Fri, 22 Nov 2024 12:25:04 +0100 Subject: [PATCH 09/12] Change default q to 1.0 in multioutput_classifier and remove tqdm and print statements --- .../models/multioutput_classifier.py | 38 ++----------------- 1 file changed, 4 insertions(+), 34 deletions(-) diff --git a/tmu/experimental/models/multioutput_classifier.py b/tmu/experimental/models/multioutput_classifier.py index 20f9161f..5975edf2 100644 --- a/tmu/experimental/models/multioutput_classifier.py +++ b/tmu/experimental/models/multioutput_classifier.py @@ -20,7 +20,6 @@ import numpy as np # import pandas as pd -from tqdm import tqdm from tmu.models.base import MultiWeightBankMixin, SingleClauseBankMixin, TMBaseModel from tmu.util.encoded_data_cache import DataEncoderCache @@ -52,7 +51,7 @@ def __init__( weighted_clauses=False, clause_drop_p=0.0, literal_drop_p=0.0, - q=-1, + q=1.0, seed=None, ): super().__init__( @@ -95,8 +94,6 @@ def init_clause_bank(self, X: np.ndarray, Y: np.ndarray): def init_weight_bank(self, X: np.ndarray, Y: np.ndarray): self.number_of_classes = Y.shape[1] - if self.q < 0: - self.q = max(1, self.number_of_classes - 1) / 2 self.weight_banks.set_clause_init( WeightBank, dict( @@ -114,7 +111,7 @@ def init_after(self, X: np.ndarray, Y: np.ndarray): if self.max_positive_clauses is None: self.max_positive_clauses = self.number_of_clauses - def fit(self, X, Y, shuffle=True, progress_bar=False, met=False, **kwargs): + def fit(self, X, Y, shuffle=True, **kwargs): self.init(X, Y) encoded_X_train = self.train_encoder_cache.get_encoded_data( @@ -154,8 +151,6 @@ def fit(self, X, Y, shuffle=True, progress_bar=False, met=False, **kwargs): if shuffle: self.rng.shuffle(shuffled_index) - pbar = tqdm(shuffled_index) if progress_bar else shuffled_index - # Combine all weight banks, to make use of faster numpy matrix operation self.wcomb = np.empty( (self.number_of_clauses, self.number_of_classes), dtype=np.int32 @@ -173,7 +168,7 @@ def fit(self, X, Y, shuffle=True, progress_bar=False, met=False, **kwargs): self.update_p_per_sample = np.empty((X.shape[0], self.number_of_classes)) # self.avg_n_neg_classes = 0 - for e in pbar: + for e in shuffled_index: clause_outputs = self.clause_bank.calculate_clause_outputs_update( self.literal_active, encoded_X_train, e ) @@ -259,30 +254,6 @@ def fit(self, X, Y, shuffle=True, progress_bar=False, met=False, **kwargs): self.wcomb[:, c] = self.weight_banks[c].get_weights() self.update_ps[c] = 0.0 self.nf[c] += 1 - # self.avg_n_neg_classes += 1 - # print( - # f"Average num of neg classes selected = {self.avg_n_neg_classes / Y.shape[0]}" - # ) - # print( - # pd.DataFrame( - # { - # "n_pos": self.pf, - # "n_neg": self.nf, - # "R": self.pf / self.nf, - # "n_lab": Y.sum(axis=0), - # "n_lab_e": Y.sum(axis=0) / Y.shape[0], - # "n_nlb": Y.shape[0] - Y.sum(axis=0), - # "n_nlb_e": (Y.shape[0] - Y.sum(axis=0)) / Y.shape[0], - # } - # ) - # ) - if met: - return { - "pf": self.pf, - "nf": self.nf, - "class_sums": self.class_sums_per_sample, - "update_p": self.update_p_per_sample, - } def predict( self, @@ -301,11 +272,10 @@ def predict( shuffled_index = np.arange(X.shape[0]) if shuffle: self.rng.shuffle(shuffled_index) - pbar = tqdm(shuffled_index) if progress_bar else shuffled_index # Compute class sums for all samples class_sums = np.empty((X.shape[0], self.number_of_classes)) - for e in pbar: + for e in shuffled_index: class_sums[e, :] = self.compute_class_sums( encoded_X_test, e, clip_class_sum ) From 62c0e6a62aa690bff8fb0246798543376e46fc20 Mon Sep 17 00:00:00 2001 From: Mayur1009 Date: Fri, 22 Nov 2024 12:35:08 +0100 Subject: [PATCH 10/12] Change for numpy>=2.0 --- tmu/clause_bank/clause_bank.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tmu/clause_bank/clause_bank.py b/tmu/clause_bank/clause_bank.py index fb88bf3c..2b3c2af0 100644 --- a/tmu/clause_bank/clause_bank.py +++ b/tmu/clause_bank/clause_bank.py @@ -133,7 +133,7 @@ def initialize_clauses(self): order="c" ) - # TODO: MAKE issue for newer numpy version for np.uint32(-1) + # np.uint32(~0) will be deprecated in numpy>=2.0, changed to np.array(~0).astype(np.uint32) self.clause_bank[:, :, 0: self.number_of_state_bits_ta - 1] = np.array(~0).astype(np.uint32) self.clause_bank[:, :, self.number_of_state_bits_ta - 1] = 0 self.clause_bank = np.ascontiguousarray(self.clause_bank.reshape( @@ -143,6 +143,8 @@ def initialize_clauses(self): self.clause_bank_ind = np.empty( (self.number_of_clauses, self.number_of_ta_chunks, self.number_of_state_bits_ind), dtype=np.uint32) + + # np.uint32(~0) will be deprecated in numpy>=2.0, changed to np.array(~0).astype(np.uint32) self.clause_bank_ind[:, :, :] = np.array(~0).astype(np.uint32) self.clause_bank_ind = np.ascontiguousarray(self.clause_bank_ind.reshape( From 2eb9226bf9b10076968927a66f8298930225acd8 Mon Sep 17 00:00:00 2001 From: Mayur1009 Date: Fri, 22 Nov 2024 12:38:37 +0100 Subject: [PATCH 11/12] Revert to_cpu function --- .../models/multioutput_classifier.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/tmu/experimental/models/multioutput_classifier.py b/tmu/experimental/models/multioutput_classifier.py index 5975edf2..deb98923 100644 --- a/tmu/experimental/models/multioutput_classifier.py +++ b/tmu/experimental/models/multioutput_classifier.py @@ -305,30 +305,6 @@ def compute_class_sums(self, encoded_X_test, ith_sample: int, clip_class_sum: bo class_sums = np.clip(class_sums, -self.T, self.T).astype(np.int32) return class_sums - def to_cpu(self): - if self.platform in ["GPU", "CUDA"]: - arr = np.empty((self.X_shape)) - clause_bank_gpu = self.clause_bank - clause_bank_gpu.synchronize_clause_bank() - clause_bank_type, clause_bank_args = self._build_cpu_bank(arr) - clause_bank_cpu = clause_bank_type(**clause_bank_args) - - clause_bank_cpu.clause_bank = clause_bank_gpu.clause_bank - clause_bank_cpu.clause_output = clause_bank_gpu.clause_output - clause_bank_cpu.literal_clause_count = clause_bank_gpu.literal_clause_count - - clause_bank_cpu._cffi_init() - - self.clause_bank = clause_bank_cpu - self.platform = "CPU" - print("to_cpu(): Successful....") - - elif self.platform == "CPU": - print("to_cpu(): Already CPU....") - - else: - print("to_cpu(): Not implemented....") - def clause_precision(self, the_class, X, Y): clause_outputs = self.transform(X) weights = self.weight_banks[the_class].get_weights() From 34eb9c71c7678e2e86821ae71da9f15c0d30f3cb Mon Sep 17 00:00:00 2001 From: Mayur1009 Date: Fri, 22 Nov 2024 12:41:37 +0100 Subject: [PATCH 12/12] Revert logging_example --- tmu/logging_example.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tmu/logging_example.json b/tmu/logging_example.json index c3314cbb..6d7c4afc 100644 --- a/tmu/logging_example.json +++ b/tmu/logging_example.json @@ -20,12 +20,12 @@ "console" ], "propagate": false - }, - "root": { - "level": "DEBUG", - "handlers": [ - "console" - ] } + }, + "root": { + "level": "DEBUG", + "handlers": [ + "console" + ] } }