Open
Hi. I noticed that the way you calculate the radial differs from other similar works, e.g., EGNN.
Your strategy: the radial is the matrix of pairwise dot products between the channel-wise coordinate differences.

def coord2radial(edge_index, coord):
    row, col = edge_index
    coord_diff = coord[row] - coord[col]  # [n_edge, n_channel, d]
    radial = torch.bmm(coord_diff, coord_diff.transpose(-1, -2))  # [n_edge, n_channel, n_channel]
    # normalize radial
    radial = F.normalize(radial, dim=0)  # [n_edge, n_channel, n_channel]
    return radial, coord_diff

EGNN's strategy: the radial is the squared distance between two nodes.

def coord2radial(self, edge_index, coord):
    row, col = edge_index
    coord_diff = coord[row] - coord[col]
    radial = torch.sum(coord_diff**2, 1).unsqueeze(1)
    if self.normalize:
        norm = torch.sqrt(radial).detach() + self.epsilon
        coord_diff = coord_diff / norm
    return radial, coord_diff

I think your radial can represent the relative orientation of two multi-channel residues, while EGNN's radial represents only the distance. Is this interpretation reasonable? What do you think it represents, and what was your motivation for defining it this way instead of following EGNN?
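
To make the comparison concrete, here is a minimal sketch (not code from either repository) of how the two definitions relate before any normalization. The tensor sizes and the random `coord_diff` are assumptions chosen only for illustration. It checks that the diagonal of the multi-channel radial equals the EGNN-style squared distance computed per channel, that the off-diagonal entries are scaled cosines between channels, and that the whole matrix is unchanged by a shared orthogonal transform of the coordinates.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, for illustration only.
n_edge, n_channel, d = 5, 4, 3
coord_diff = torch.randn(n_edge, n_channel, d)

# Multi-channel radial as in the first snippet, without the normalization step.
gram = torch.bmm(coord_diff, coord_diff.transpose(-1, -2))  # [n_edge, n_channel, n_channel]

# EGNN-style squared distance, computed separately for each channel.
sq_dist = torch.sum(coord_diff ** 2, dim=-1)  # [n_edge, n_channel]

# The diagonal of the Gram matrix holds exactly those squared distances.
diag = torch.diagonal(gram, dim1=-2, dim2=-1)  # [n_edge, n_channel]
assert torch.allclose(diag, sq_dist, atol=1e-5)

# Off-diagonal entries are |a| * |b| * cos(angle); dividing by the norms
# leaves the cosine between the difference vectors of two channels.
norms = sq_dist.sqrt()
cosine = gram / (norms.unsqueeze(-1) * norms.unsqueeze(-2))

# Applying the same orthogonal matrix Q to every difference vector
# leaves the Gram matrix unchanged.
Q, _ = torch.linalg.qr(torch.randn(d, d))
rotated = coord_diff @ Q
gram_rotated = torch.bmm(rotated, rotated.transpose(-1, -2))
assert torch.allclose(gram, gram_rotated, atol=1e-4)
```

With `n_channel = 1` the Gram matrix collapses to a single squared distance per edge, which is the EGNN radial up to shape.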
The way you normalize the radial is also interesting: you normalize it along the n_edge dimension (dim=0), which plays a role similar to a batch dimension.
Why did you choose this? Have you tried removing the normalization?
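
A small sketch of what `F.normalize(radial, dim=0)` does may help frame the question. The sizes and random inputs below are assumptions for illustration, not values from the repository. The sketch shows that each `(i, j)` entry, taken as a vector over all edges, ends up with unit L2 norm, that a global rescaling of the coordinates is removed, and that the value for a given edge changes when the set of edges in the batch changes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, for illustration only.
n_edge, n_channel, d = 6, 4, 3
coord_diff = torch.randn(n_edge, n_channel, d)
radial = torch.bmm(coord_diff, coord_diff.transpose(-1, -2))  # [n_edge, n_channel, n_channel]

# Normalization along the edge dimension, as in the first snippet.
normed = F.normalize(radial, dim=0)

# Each (i, j) entry, viewed as a vector over edges, now has unit L2 norm.
norms_over_edges = normed.norm(dim=0)  # [n_channel, n_channel]
assert torch.allclose(norms_over_edges, torch.ones(n_channel, n_channel), atol=1e-5)

# Scaling all coordinates by 2 scales the radial by 4; the normalization removes it.
normed_scaled = F.normalize(4.0 * radial, dim=0)
assert torch.allclose(normed, normed_scaled, atol=1e-5)

# The result for the first three edges depends on which other edges are present.
normed_subset = F.normalize(radial[:3], dim=0)
assert not torch.allclose(normed[:3], normed_subset, atol=1e-5)
```

The last check is the part I am asking about: the normalized radial of an edge is not a function of that edge alone, so it can vary with batch composition.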
Best,
Zhangzhi