Skip to content

Commit

Permalink
fix magnitude issue
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Feb 19, 2025
1 parent af26f6e commit a6ffc72
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions tutorials/examples/train_graph_ring.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor:

# This is n_nodes ** 2, for each graph.
edge_index = torch.einsum("bnf,bmf->bnm", edge_feature, edge_feature)
edge_index = edge_index / torch.sqrt(torch.tensor(self.hidden_dim))

# Undirected.
if self.is_directed:
Expand Down

0 comments on commit a6ffc72

Please sign in to comment.