Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hypergraph: Add equations in docstrings of forward methods #194

Merged
merged 2 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion test/nn/simplicial/test_san.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def test_forward(self):

assert torch.any(
torch.isclose(
model(x, laplacian_up_1, laplacian_down_1)[0], torch.tensor([0.7727, 0.2389]), rtol=1e-02
model(x, laplacian_up_1, laplacian_down_1)[0],
torch.tensor([0.7727, 0.2389]),
rtol=1e-02,
)
)

Expand Down
14 changes: 14 additions & 0 deletions topomodelx/nn/hypergraph/dhgcn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,20 @@ def forward(self, x_0):

Dynamic topology module of the DHST Block is implemented here.

.. math::
\begin{align*}
&🟧 \quad m_{\rightarrow z}^{\rightarrow 1} = \text{AGG}\_{y \in \mathcal{B}(z)}(h_y^{0, t})\\
&🟦 \quad h_z^{1, t+1} = \sigma(m_{\rightarrow z}^{\rightarrow 1})\\
&🟥 \quad m_{z \rightarrow x}^{1 \rightarrow 0} = M_\mathcal{B}(att(h_z^{1, t+1}), h_z^{1, t+1})\\
&🟧 \quad m_{\rightarrow x}^{\rightarrow 0} = \sum_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{0\rightarrow 1}\\
&🟦 \quad {h_x^{0, t+1}} = \text{MLP}(m_{\rightarrow x}^{\rightarrow 0})
\end{align*}

References
----------
.. [TNN23] Equations of Topological Neural Networks.
https://github.com/awesome-tnns/awesome-tnns/

Parameters
----------
x_0 : torch.Tensor, shape=[n_nodes, node_channels]
Expand Down
20 changes: 19 additions & 1 deletion topomodelx/nn/hypergraph/hmpnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,32 @@ def forward(self, x, m):


class HMPNNLayer(nn.Module):
"""HMPNN Layer introduced in Heydari et Livi 2022.
r"""HMPNN Layer introduced in Heydari et Livi 2022.

The layer is a hypergraph comprised of nodes and hyperedges that makes their new reprsentation using the input
representation and the messages passed between them. In this layer, the message passed from a node to its
neighboring hyperedges is only a function of its input representation, but the message from a hyperedge to its
neighboring nodes is also a function of the messages recieved from them beforehand. This way, a node could have
a more explicit effect on its upper adjacent neighbors i.e. the nodes that it share a hyperedge with.

.. math::
\begin{align*}
&🟥 \quad m_{{y \rightarrow z}}^{(0 \rightarrow 1)} = M_\mathcal{C} (h_y^{t,(0)}, h_z^{t, (1)})\\
&🟧 \quad m_{z'}^{(0 \rightarrow 1)} = AGG'{y \in \mathcal{B}(z)} m_{y \rightarrow z}^{(0\rightarrow1)}\\
&🟧 \quad m_{z}^{(0 \rightarrow 1)} = AGG_{y \in \mathcal{B}(z)} m_{y \rightarrow z}^{(0 \rightarrow 1)}\\
&🟥 \quad m_{z \rightarrow x}^{(1 \rightarrow0)} = M_\mathcal{B}(h_z^{t,(1)}, m_z^{(1)})\\
&🟧 \quad m_x^{(1 \rightarrow0)} = AGG_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{(1 \rightarrow0)}\\
&🟩 \quad m_x^{(0)} = m_x^{(1 \rightarrow 0)}\\
&🟩 \quad m_z^{(1)} = m_{z'}^{(0 \rightarrow 1)}\\
&🟦 \quad h_x^{t+1, (0)} = U^{(0)}(h_x^{t,(0)}, m_x^{(0)})\\
&🟦 \quad h_z^{t+1,(1)} = U^{(1)}(h_z^{t,(1)}, m_{z}^{(1)})
\end{align*}

References
----------
.. [TNN23] Equations of Topological Neural Networks.
https://github.com/awesome-tnns/awesome-tnns/

Parameters
----------
in_features : int
Expand Down
15 changes: 15 additions & 0 deletions topomodelx/nn/hypergraph/hypergat_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,21 @@ def update(self, x_message_on_target, x_target=None):
def forward(self, x_source, incidence):
r"""Forward pass.

.. math::
\begin{align*}
&🟥 \quad m_{y \rightarrow z}^{(0 \rightarrow 1) } = (B^T_1\odot att(h_{y \in \mathcal{B}(z)}^{t,(0)}))\_{zy} \cdot h^{t,(0)}y \cdot \Theta^{t,(0)}\\
&🟧 \quad m_z^{(1)} = \sigma(\sum_{y \in \mathcal{B}(z)} m_{y \rightarrow z}^{(0 \rightarrow 1)})\\
&🟥 \quad m_{z \rightarrow x}^{(1 \rightarrow 0)} = (B_1 \odot att(h_{z \in \mathcal{C}(x)}^{t,(1)}))\_{xz} \cdot m_{z}^{(1)} \cdot \Theta^{t,(1)}\\
&🟧 \quad m_{x}^{(0)} = \sum_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{(1\rightarrow0)}\\
&🟩 \quad m_x = m_{x}^{(0)}\\
&🟦 \quad h_x^{t+1, (0)} = \sigma(m_x)
\end{align*}

References
----------
.. [TNN23] Equations of Topological Neural Networks.
https://github.com/awesome-tnns/awesome-tnns/

Parameters
----------
x : torch.Tensor
Expand Down
15 changes: 15 additions & 0 deletions topomodelx/nn/hypergraph/hypersage_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,21 @@ def aggregate(self, x_messages: torch.Tensor, mode: str = "intra"):
def forward(self, x: torch.Tensor, incidence: torch.Tensor): # type: ignore[override]
r"""Forward pass.

.. math::
\begin{align*}
&🟥 \quad m_{y \rightarrow z}^{(0 \rightarrow 1)} = (B_1)^T_{zy} \cdot w_y \cdot (h_y^{(0)})^p\\
&🟥 \quad m_z^{(0 \rightarrow 1)} = \left(\frac{1}{\vert \mathcal{B}(z)\vert}\sum_{y \in \mathcal{B}(z)} m_{y \rightarrow z}^{(0 \rightarrow 1)}\right)^{\frac{1}{p}}\\
&🟥 \quad m_{z \rightarrow x}^{(1 \rightarrow 0)} = (B_1)_{xz} \cdot w_z \cdot (m_z^{(0 \rightarrow 1)})^p\\
&🟧 \quad m_x^{(1,0)} = \left(\frac{1}{\vert \mathcal{C}(x) \vert}\sum_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{(1 \rightarrow 0)}\right)^{\frac{1}{p}}\\
&🟩 \quad m_x^{(0)} = m_x^{(1 \rightarrow 0)}\\
&🟦 \quad h_x^{t+1, (0)} = \sigma \left(\frac{m_x^{(0)} + h_x^{t,(0)}}{\lvert m_x^{(0)} + h_x^{t,(0)}\rvert} \cdot \Theta^t\right)
\end{align*}

References
----------
.. [TNN23] Equations of Topological Neural Networks.
https://github.com/awesome-tnns/awesome-tnns/

Parameters
----------
x : torch.Tensor
Expand Down