Skip to content

Commit

Permalink
Merge pull request #291 from pyt-team/update_readme_2
Browse files Browse the repository at this point in the history
making code in the  readme more clear
  • Loading branch information
ffl096 authored Oct 21, 2024
2 parents 106db16 + 26317e0 commit 0f99fe6
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,19 @@ from topomodelx.nn.simplicial.san import SAN
from topomodelx.utils.sparse import from_sparse

# Step 1: Load the Karate Club dataset
dataset = tnx.datasets.karate_club(complex_type="simplicial")
karate_club_complex = tnx.datasets.karate_club(complex_type="simplicial")

# Step 2: Prepare Laplacians and node/edge features
laplacian_down = from_sparse(dataset.down_laplacian_matrix(rank=1))
laplacian_up = from_sparse(dataset.up_laplacian_matrix(rank=1))
incidence_0_1 = from_sparse(dataset.incidence_matrix(rank=1))
laplacian_down = from_sparse(karate_club_complex.down_laplacian_matrix(rank=1))
laplacian_up = from_sparse(karate_club_complex.up_laplacian_matrix(rank=1))
incidence_0_1 = from_sparse(karate_club_complex.incidence_matrix(rank=1))

x_0 = torch.tensor(np.stack(list(dataset.get_simplex_attributes("node_feat").values())))
x_1 = torch.tensor(np.stack(list(dataset.get_simplex_attributes("edge_feat").values())))
x_0 = torch.tensor(np.stack(list(karate_club_complex.get_simplex_attributes("node_feat").values())))
x_1 = torch.tensor(np.stack(list(karate_club_complex.get_simplex_attributes("edge_feat").values())))
x = x_1 + torch.sparse.mm(incidence_0_1.T, x_0)

# Step 3: Define the network
class Network(torch.nn.Module):
class TNN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.base_model = SAN(in_channels, hidden_channels, n_layers=2)
Expand All @@ -73,7 +73,7 @@ class Network(torch.nn.Module):
return torch.sigmoid(self.linear(x))

# Step 4: Initialize the network and perform a forward pass
model = Network(in_channels=x.shape[-1], hidden_channels=16, out_channels=2)
model = TNN(in_channels=x.shape[-1], hidden_channels=16, out_channels=2)
y_hat_edge = model(x, laplacian_up=laplacian_up, laplacian_down=laplacian_down)
```

Expand Down

0 comments on commit 0f99fe6

Please sign in to comment.