pip install hrr
- v1.2.3 - support for real-valued FFT, which can be accessed via HRR.real
- v1.1.0 - dim/axis support for PyTorch, JAX & Flax. For TensorFlow, binding/unbinding can only be applied to the last dimension.
Holographic Reduced Representations (HRR) is a method of representing compositional structures using circular convolution in distributed representations. The HRR operations binding and unbinding allow assigning abstract concepts to arbitrary numerical vectors. Given vectors x and y in a d-dimensional space, the two can be combined using the binding operation. Conversely, given the bound vector and one of the two original vectors, the other can be retrieved using the unbinding operation.
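In symbols, the standard formulation can be sketched as follows, where ⊛ is circular convolution, F is the discrete Fourier transform, and ⊙ / ⊘ are element-wise product and division. This is a sketch assuming the exact inverse and no zero Fourier coefficients; it is not a transcription of the package's internals.

```latex
% Binding: circular convolution (indices mod d), computed via the DFT \mathcal{F}
(\mathbf{x} \circledast \mathbf{y})_i
  = \sum_{j=0}^{d-1} x_j \, y_{(i-j) \bmod d}
  = \mathcal{F}^{-1}\!\big(\mathcal{F}(\mathbf{x}) \odot \mathcal{F}(\mathbf{y})\big)_i

% Exact inverse of x (assumes every Fourier coefficient of x is nonzero)
\mathbf{x}^{-1} = \mathcal{F}^{-1}\!\big(\mathbf{1} \oslash \mathcal{F}(\mathbf{x})\big)

% Unbinding: bind the composite b = x (*) y with the inverse of x
\mathbf{b} \circledast \mathbf{x}^{-1}
  = \mathcal{F}^{-1}\!\big(\mathcal{F}(\mathbf{x}) \odot \mathcal{F}(\mathbf{y}) \oslash \mathcal{F}(\mathbf{x})\big)
  = \mathbf{y}
```

When every Fourier coefficient of x has unit magnitude, 1 ⊘ F(x) equals the complex conjugate of F(x), so for real x the exact inverse coincides with the simple involution x⁻¹ᵢ = x₍₋ᵢ₎ mod d.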
The HRR library supports TensorFlow, PyTorch, JAX, and Flax. To import the HRR package with the TensorFlow backend, use HRR.with_tensorflow; to import it with the JAX backend, use HRR.with_jax, and so on. Vectors are sampled with the normal function from a normal distribution with zero mean and variance equal to the inverse of the dimension (1/d). The projection function then projects them onto the ball of complex unit magnitude, as proposed in Learning with Holographic Reduced Representations, which ensures that the inverse is numerically stable during unbinding.
from HRR.with_pytorch import normal, projection, binding, unbinding, cosine_similarity
batch = 32
features = 256
x = projection(normal(shape=(batch, features), seed=0), dim=-1)
y = projection(normal(shape=(batch, features), seed=1), dim=-1)
b = binding(x, y, dim=-1)
y_prime = unbinding(b, x, dim=-1)
score = cosine_similarity(y, y_prime, dim=-1, keepdim=False)
print('score:', score[0])
# prints score: tensor(1.0000)
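The sample, project, bind, unbind round trip above can be sketched without any framework. The following is a minimal pure-Python illustration, not the package's implementation: it uses a naive O(d²) DFT in place of an FFT, works on a single vector instead of a batch, and its functions only borrow the package's names for readability (their signatures are hypothetical and have no dim argument).

```python
import cmath
import math
import random


def dft(v):
    """Naive O(d^2) discrete Fourier transform (stand-in for an FFT)."""
    d = len(v)
    return [sum(v[n] * cmath.exp(-2j * math.pi * k * n / d) for n in range(d))
            for k in range(d)]


def idft_real(spec):
    """Inverse DFT keeping only the real part (input assumed Hermitian)."""
    d = len(spec)
    return [(sum(spec[k] * cmath.exp(2j * math.pi * k * n / d) for k in range(d)) / d).real
            for n in range(d)]


def normal(d, seed):
    """Sample d values from N(0, 1/d), i.e. standard deviation 1/sqrt(d)."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d)]


def projection(x, eps=1e-12):
    """Rescale every Fourier coefficient to unit magnitude, keeping its phase."""
    return idft_real([c / max(abs(c), eps) for c in dft(x)])


def binding(x, y):
    """Circular convolution: element-wise product in the Fourier domain."""
    return idft_real([a * b for a, b in zip(dft(x), dft(y))])


def unbinding(b, x):
    """Bind with the exact inverse of x: element-wise division in the Fourier domain."""
    return idft_real([s / a for s, a in zip(dft(b), dft(x))])


def cosine_similarity(a, b):
    dot = sum(p * q for p, q in zip(a, b))
    return dot / (math.sqrt(sum(p * p for p in a)) * math.sqrt(sum(q * q for q in b)))


d = 64
x = projection(normal(d, seed=0))
y = projection(normal(d, seed=1))
y_prime = unbinding(binding(x, y), x)
print(max(abs(abs(c) - 1.0) for c in dft(x)))  # close to 0: unit-magnitude spectrum
print(cosine_similarity(y, y_prime))           # close to 1.0
```

The naive DFT is only there to keep the sketch dependency-free; with NumPy or any of the supported frameworks, an FFT would replace dft and idft_real.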
What makes HRR more interesting is that multiple bound pairs can be combined by element-wise addition; however, retrieval accuracy decreases as more pairs are added.
x = projection(normal(shape=(batch, features), seed=0), dim=-1)
y = projection(normal(shape=(batch, features), seed=1), dim=-1)
w = projection(normal(shape=(batch, features), seed=2), dim=-1)
z = projection(normal(shape=(batch, features), seed=3), dim=-1)
b = binding(x, y, dim=-1) + binding(w, z, dim=-1)
y_prime = unbinding(b, x, dim=-1)
score = cosine_similarity(y, y_prime, dim=-1, keepdim=False)
print('score:', score[0])
# prints score: tensor(0.7483)
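To see how retrieval degrades as more bound pairs are added to the sum, here is a self-contained pure-Python sketch (not the package's code). As a simplifying assumption, vectors are kept in the Fourier domain, where binding is an element-wise product, bundling is an element-wise sum, and unbinding is element-wise division; superposition_scores is a hypothetical helper.

```python
import cmath
import math
import random


def dft(v, sign=-1):
    """Naive O(d^2) DFT; sign=+1 gives the unnormalized inverse transform."""
    d = len(v)
    return [sum(v[n] * cmath.exp(sign * 2j * math.pi * k * n / d) for n in range(d))
            for k in range(d)]


def to_real(spec):
    """Inverse DFT of a Hermitian spectrum, returned as a real vector."""
    return [(c / len(spec)).real for c in dft(spec, sign=1)]


def random_unitary(d, rng):
    """Spectrum of an N(0, 1/d) sample, projected to unit magnitude."""
    spec = dft([rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d)])
    return [c / max(abs(c), 1e-12) for c in spec]


def cosine(a, b):
    dot = sum(p * q for p, q in zip(a, b))
    return dot / math.sqrt(sum(p * p for p in a) * sum(q * q for q in b))


def superposition_scores(d=64, max_pairs=4, seed=0):
    """Similarity between value 0 and what key 0 retrieves from a bundle of k pairs."""
    rng = random.Random(seed)
    keys = [random_unitary(d, rng) for _ in range(max_pairs)]
    values = [random_unitary(d, rng) for _ in range(max_pairs)]
    target = to_real(values[0])
    scores = []
    for k in range(1, max_pairs + 1):
        # Bundle k bound pairs: sum of element-wise products in the Fourier domain.
        bundle = [sum(keys[i][j] * values[i][j] for i in range(k)) for j in range(d)]
        # Unbind with the exact inverse of key 0 (element-wise division).
        retrieved = to_real([bundle[j] / keys[0][j] for j in range(d)])
        scores.append(cosine(target, retrieved))
    return scores


print(superposition_scores())  # first entry ~1.0, later entries lower (roughly 1/sqrt(k))
```

Each extra pair contributes a noise term of about the same norm as the target, which is why similarity drops roughly like 1/sqrt(k) rather than collapsing to zero.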
Vectors can also be bound and retrieved in hierarchical order (# denotes binding):
 x     y
  \   /
   \ /
  b=x#y     z
      \    /
       \  /
    c=(x#y)#z
x = projection(normal(shape=(batch, features), seed=0), dim=-1)
y = projection(normal(shape=(batch, features), seed=1), dim=-1)
z = projection(normal(shape=(batch, features), seed=2), dim=-1)
b = binding(x, y, dim=-1)
c = binding(b, z, dim=-1)
b_ = unbinding(c, z, dim=-1)
y_ = unbinding(b_, x, dim=-1)
score = cosine_similarity(y, y_, dim=-1)
print('score:', score[0])
# prints score: tensor(1.0000)
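Why the nested example recovers y: in the Fourier domain every binding is an element-wise product, so nested bindings simply multiply spectra and each unbinding divides one factor back out. A sketch, assuming the exact inverse F(z⁻¹) = 1 ⊘ F(z) and no zero Fourier coefficients (which projection ensures):

```latex
% b = x (*) y,  c = b (*) z
\mathcal{F}(\mathbf{c}) = \mathcal{F}(\mathbf{x}) \odot \mathcal{F}(\mathbf{y}) \odot \mathcal{F}(\mathbf{z})

% Step 1: unbind z to recover b
\mathcal{F}(\mathbf{c} \circledast \mathbf{z}^{-1})
  = \mathcal{F}(\mathbf{c}) \oslash \mathcal{F}(\mathbf{z})
  = \mathcal{F}(\mathbf{x}) \odot \mathcal{F}(\mathbf{y})
  = \mathcal{F}(\mathbf{b})

% Step 2: unbind x to recover y
\mathcal{F}(\mathbf{b} \circledast \mathbf{x}^{-1})
  = \mathcal{F}(\mathbf{b}) \oslash \mathcal{F}(\mathbf{x})
  = \mathcal{F}(\mathbf{y})
```

Because element-wise multiplication is commutative, the two keys could also be unbound in the opposite order.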
The HRR package also provides binding/unbinding as Flax modules. Like any other Flax module, the model must first be initialized with init and then executed with the apply method.
import time

import jax
import numpy as np
import flax.linen as nn
# Module path assumed from the HRR.with_<backend> naming pattern
from HRR.with_flax import normal, Binding, Unbinding, Projection, CosineSimilarity

batch, features = 32, 256
x = normal(shape=(batch, features), seed=0)
y = normal(shape=(batch, features), seed=1)


class Model(nn.Module):
    def setup(self):
        # Submodules are declared in setup, so __call__ does not need @nn.compact
        self.binding = Binding()
        self.unbinding = Unbinding()
        self.projection = Projection()
        self.similarity = CosineSimilarity()

    def __call__(self, x, y, axis):
        # Project both inputs, bind them, then unbind y back out using x
        x = self.projection(x, axis=axis)
        y = self.projection(y, axis=axis)
        b = self.binding(x, y, axis=axis)
        y_ = self.unbinding(b, x, axis=axis)
        return self.similarity(y, y_, axis=axis, keepdims=False)


model = Model()
init_value = {'x': np.ones_like(x), 'y': np.ones_like(y), 'axis': -1}
var = model.init(jax.random.PRNGKey(0), **init_value)

tic = time.time()
inputs = {'x': x, 'y': y, 'axis': -1}
score = model.apply(var, **inputs)
toc = time.time()
print(score)
print(f'score: {score[0]:.2f}')
print(f'Total time: {toc - tic:.4f}s')
# prints score: 1.00
# Total time: 0.0088s (timing varies by machine)
apply.py shows an example of how to apply binding/unbinding to an image. The binding operation combines the original image with another matrix sampled from a normal distribution, producing the bound image as their composite representation. Using the unbinding operation, the original image can be retrieved without any loss.
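apply.py itself is not reproduced here. As a rough pure-Python illustration, the sketch below assumes a 2D circular convolution (DFT over rows, then columns), in the spirit of the 2D HRR paper listed below; apply.py may bind differently. The 8x8 random "image", the helper names (bind2d, unbind2d, random_key) and the unit-magnitude normalization of the key are all illustrative assumptions, the last one mirroring the projection step so that the division in unbinding stays numerically stable.

```python
import cmath
import math
import random


def dft(v, sign=-1):
    """Naive O(n^2) 1D DFT; sign=+1 gives the unnormalized inverse."""
    n = len(v)
    return [sum(v[t] * cmath.exp(sign * 2j * math.pi * k * t / n) for t in range(n))
            for k in range(n)]


def dft2(a, sign=-1):
    """2D DFT: transform every row, then every column."""
    rows = [dft(r, sign) for r in a]
    cols = [dft(list(c), sign) for c in zip(*rows)]
    return [list(r) for r in zip(*cols)]


def idft2_real(spec):
    """Inverse 2D DFT of a Hermitian spectrum, returned as a real matrix."""
    h, w = len(spec), len(spec[0])
    return [[(c / (h * w)).real for c in row] for row in dft2(spec, sign=1)]


def bind2d(img, key):
    """Hypothetical 2D binding: element-wise product of 2D spectra."""
    A, K = dft2(img), dft2(key)
    return idft2_real([[a * k for a, k in zip(ra, rk)] for ra, rk in zip(A, K)])


def unbind2d(bound, key):
    """Hypothetical 2D unbinding: element-wise division by the key's spectrum."""
    B, K = dft2(bound), dft2(key)
    return idft2_real([[b / k for b, k in zip(rb, rk)] for rb, rk in zip(B, K)])


def random_key(h, w, seed):
    """N(0, 1/(h*w)) matrix projected to a unit-magnitude 2D spectrum."""
    rng = random.Random(seed)
    m = [[rng.gauss(0.0, 1.0 / math.sqrt(h * w)) for _ in range(w)] for _ in range(h)]
    return idft2_real([[c / max(abs(c), 1e-12) for c in row] for row in dft2(m)])


# A tiny 8x8 grayscale "image" with pixel values in [0, 1].
rng = random.Random(42)
image = [[rng.random() for _ in range(8)] for _ in range(8)]
key = random_key(8, 8, seed=0)
bound = bind2d(image, key)        # scrambled composite of image and key
recovered = unbind2d(bound, key)
print(max(abs(p - q) for rp, rq in zip(image, recovered) for p, q in zip(rp, rq)))
# close to 0: only floating-point error remains
```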
Deploying Convolutional Networks on Untrusted Platforms Using 2D Holographic Reduced Representations
@inproceedings{Alam2022,
archivePrefix = {arXiv},
arxivId = {2206.05893},
author = {Alam, Mohammad Mahmudul and Raff, Edward and Oates, Tim and Holt, James},
booktitle = {International Conference on Machine Learning},
eprint = {2206.05893},
title = {{Deploying Convolutional Networks on Untrusted Platforms Using 2D Holographic Reduced Representations}},
url = {http://arxiv.org/abs/2206.05893},
year = {2022}
}
Recasting Self-Attention with Holographic Reduced Representations
@inproceedings{alam2023recasting,
title={Recasting self-attention with holographic reduced representations},
author={Alam, Mohammad Mahmudul and Raff, Edward and Biderman, Stella and Oates, Tim and Holt, James},
booktitle={International Conference on Machine Learning},
pages={490--507},
year={2023},
organization={PMLR}
}
Towards Generalization in Subitizing with Neuro-Symbolic Loss using Holographic Reduced Representations
@article{alam2023towards,
title={Towards generalization in subitizing with neuro-symbolic loss using holographic reduced representations},
author={Alam, Mohammad Mahmudul and Raff, Edward and Oates, Tim},
journal={arXiv preprint arXiv:2312.15310},
year={2023}
}
Holographic Global Convolutional Networks for Long-Range Prediction Tasks in Malware Detection @ AISTATS 2024
@inproceedings{alam2024holographic,
title={Holographic Global Convolutional Networks for Long-Range Prediction Tasks in Malware Detection},
author={Alam, Mohammad Mahmudul and Raff, Edward and Biderman, Stella R and Oates, Tim and Holt, James},
booktitle={International Conference on Artificial Intelligence and Statistics},
pages={4042--4050},
year={2024},
organization={PMLR}
}
A Walsh Hadamard Derived Linear Vector Symbolic Architecture @ NeurIPS 2024
@article{alam2024walsh,
title={A Walsh Hadamard Derived Linear Vector Symbolic Architecture},
author={Alam, Mohammad Mahmudul and Oberle, Alexander and Raff, Edward and Biderman, Stella and Oates, Tim and Holt, James},
journal={arXiv preprint arXiv:2410.22669},
year={2024}
}
To report a bug or ask any other questions, please feel free to open an issue.
Thanks to @EdwardRaffML and @oatesbag for their constant support of this research endeavor.