Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix edge_rot_mat initialization #847

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 0 additions & 55 deletions src/fairchem/core/models/equiformer_v2/edge_rot_mat.py

This file was deleted.

17 changes: 6 additions & 11 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
GraphModelMixin,
HeadInterface,
)
from fairchem.core.models.escn.edge_rot_mat import init_edge_rot_mat
from fairchem.core.models.scn.smearing import GaussianSmearing

with contextlib.suppress(ImportError):
Expand All @@ -23,7 +24,6 @@

import typing

from .edge_rot_mat import init_edge_rot_mat
from .gaussian_rbf import GaussianRadialBasisLayer
from .input_block import EdgeDegreeEmbedding
from .layer_norm import (
Expand Down Expand Up @@ -443,9 +443,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
###############################################################

# Compute 3x3 rotation matrix per edge
edge_rot_mat = self._init_edge_rot_mat(
data, graph.edge_index, graph.edge_distance_vec
)
edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
for i in range(self.num_resolutions):
Expand Down Expand Up @@ -569,10 +567,6 @@ def _init_gp_partitions(
edge_distance_vec,
)

# Initialize the edge rotation matrices
def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec):
return init_edge_rot_mat(edge_distance_vec)

@property
def num_params(self):
return sum(p.numel() for p in self.parameters())
Expand Down Expand Up @@ -610,7 +604,7 @@ def no_weight_decay(self) -> set:

@registry.register_model("equiformer_v2_energy_head")
class EquiformerV2EnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone, reduce: str="sum"):
def __init__(self, backbone, reduce: str = "sum"):
super().__init__()
self.reduce = reduce
self.avg_num_nodes = backbone.avg_num_nodes
Expand Down Expand Up @@ -645,8 +639,9 @@ def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
elif self.reduce == "mean":
return {"energy": energy / data.natoms}
else:
raise ValueError(f"reduce can only be sum or mean, user provided: {self.reduce}")

raise ValueError(
f"reduce can only be sum or mean, user provided: {self.reduce}"
)


@registry.register_model("equiformer_v2_force_head")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
pass


from fairchem.core.models.escn.edge_rot_mat import init_edge_rot_mat

from .edge_rot_mat import init_edge_rot_mat
from .gaussian_rbf import GaussianRadialBasisLayer
from .input_block import EdgeDegreeEmbedding
from .layer_norm import (
Expand Down Expand Up @@ -485,9 +485,7 @@ def forward(self, data):
###############################################################

# Compute 3x3 rotation matrix per edge
edge_rot_mat = self._init_edge_rot_mat(
data, graph.edge_index, graph.edge_distance_vec
)
edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
for i in range(self.num_resolutions):
Expand Down Expand Up @@ -619,10 +617,6 @@ def forward(self, data):

return outputs

# Initialize the edge rotation matrices
def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec):
return init_edge_rot_mat(edge_distance_vec)

@property
def num_params(self):
return sum(p.numel() for p in self.parameters())
Expand Down
69 changes: 69 additions & 0 deletions src/fairchem/core/models/escn/edge_rot_mat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

import logging
import math

import torch


# Algorithm from Ken Whatmough (https://math.stackexchange.com/users/918128/ken-whatmough)
def vec3_to_perp_vec3(v: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Return, for each 3-vector in ``v``, a vector perpendicular to it.

    For an input ``(x, y, z)`` the output is::

        ( s(x)|z|,  s(y)|z|,  -s(z)(|x| + |y|) )

    where ``s(a)`` is the sign applied by ``copysign`` (so ``a * s(a) = |a|``).

    Small proof of orthogonality::

        input dot output
            = x*s(x)*|z| + y*s(y)*|z| - z*s(z)*|x| - z*s(z)*|y|
            = |x|*|z| + |y|*|z| - |z|*|x| - |z|*|y|
            = 0

    The result is not normalised. It is non-zero whenever the input is
    non-zero: if ``z != 0`` the first two components have magnitude ``|z|``,
    otherwise the third has magnitude ``|x| + |y|``.

    Args:
        v: Tensor whose axis ``dim`` holds the (x, y, z) components.
        dim: Axis holding the vector components. Defaults to 1, i.e. the
            ``(N, 3)`` layout.

    Returns:
        Tensor of perpendicular vectors with components laid out along ``dim``.
    """
    # narrow() keeps the sliced axis, so the two pieces broadcast against each
    # other inside copysign exactly like the (N, 2) / (N, 1) column slices do.
    xy = v.narrow(dim, 0, 2)
    z = v.narrow(dim, 2, 1)
    return torch.cat(
        [
            z.copysign(xy),
            -xy.copysign(z).sum(dim=dim, keepdim=True),
        ],
        dim=dim,
    )


# https://en.wikipedia.org/wiki/Rodrigues'_rotation_formula#Matrix_notation
def vec3_rotate_around_axis(
    v: torch.Tensor,
    axis: torch.Tensor,
    thetas: torch.Tensor,
    dim: int = 1,
) -> torch.Tensor:
    """Rotate each vector in ``v`` about the matching ``axis`` by ``thetas``.

    Implements Rodrigues' rotation formula in its cross-product form::

        v_rot = v + sin(theta) * (axis x v) + (1 - cos(theta)) * (axis x (axis x v))

    The formula describes a pure rotation only when ``axis`` has unit length;
    no normalisation is performed here.

    Args:
        v: Vectors to rotate, with 3 components along ``dim``.
        axis: Rotation axes, same layout as ``v``.
        thetas: Rotation angles in radians; must broadcast against ``v``
            (e.g. shape ``(N, 1)`` for ``(N, 3)`` vectors).
        dim: Axis along which the cross products are taken. Defaults to 1,
            i.e. the ``(N, 3)`` layout.

    Returns:
        The rotated vectors, same layout as ``v``.
    """
    axis_cross_v = torch.cross(axis, v, dim=dim)
    axis_cross_axis_cross_v = torch.cross(axis, axis_cross_v, dim=dim)
    return (
        v
        + torch.sin(thetas) * axis_cross_v
        + (1 - torch.cos(thetas)) * axis_cross_axis_cross_v
    )


def init_edge_rot_mat(edge_distance_vec: torch.Tensor) -> torch.Tensor:
    """Build one 3x3 rotation matrix per edge from the edge displacement vectors.

    For every edge an orthonormal frame is constructed whose middle row is the
    unit edge direction (``norm_x``). The remaining in-plane orientation is
    chosen uniformly at random on each call via ``torch.rand``, so the result
    is not deterministic unless the torch RNG is seeded. Because the rows are
    ``[norm_z, norm_x, -norm_y]`` with ``norm_z`` and ``norm_y`` perpendicular
    to ``norm_x``, each matrix maps its unit edge direction onto ``(0, 1, 0)``.

    The input is detached first, so no gradient flows through the result.

    Args:
        edge_distance_vec: Edge displacement vectors with the 3 components
            along dim 1, i.e. shape ``(num_edges, 3)``. May be empty.

    Returns:
        Contiguous tensor of shape ``(num_edges, 3, 3)``.
    """
    edge_vec_0 = edge_distance_vec.detach()
    edge_vec_0_distance = torch.linalg.norm(edge_vec_0, dim=1, keepdim=True)

    # Make sure the atoms are far enough apart. torch.min() raises on an empty
    # tensor, so the check is skipped when there are no edges. Near-zero
    # lengths are only reported (best effort); the division below will then
    # produce non-finite values for those edges.
    if edge_vec_0_distance.numel() > 0:
        min_distance = torch.min(edge_vec_0_distance)
        if min_distance < 0.0001:
            logging.error(f"Error edge_vec_0_distance: {min_distance}")

    norm_x = edge_vec_0 / edge_vec_0_distance

    # Start from a deterministic perpendicular and spin it by a random angle
    # around the edge direction; rotating about norm_x keeps it perpendicular.
    perp_to_norm_x = vec3_to_perp_vec3(norm_x)
    random_angles = (
        torch.rand((norm_x.shape[0], 1), device=norm_x.device) * 2 * math.pi
    )
    random_rotated_in_plane_perp_to_norm_x = vec3_rotate_around_axis(
        perp_to_norm_x, norm_x, random_angles
    )

    norm_z = random_rotated_in_plane_perp_to_norm_x / torch.linalg.norm(
        random_rotated_in_plane_perp_to_norm_x, dim=1, keepdim=True
    )

    norm_y = torch.cross(norm_x, norm_z, dim=1)
    norm_y = norm_y / torch.linalg.norm(norm_y, dim=1, keepdim=True)

    # Construct the 3D rotation matrix with rows [norm_z, norm_x, -norm_y].
    # torch.stack is used instead of view(-1, 1, 3) + cat because view() with
    # an inferred -1 dimension is ambiguous (and raises) for zero edges.
    return torch.stack([norm_z, norm_x, -norm_y], dim=1).contiguous()
73 changes: 7 additions & 66 deletions src/fairchem/core/models/escn/escn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from fairchem.core.common.registry import registry
from fairchem.core.common.utils import conditional_grad
from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface
from fairchem.core.models.escn.edge_rot_mat import init_edge_rot_mat
from fairchem.core.models.escn.so3 import (
CoefficientMapping,
SO3_Embedding,
Expand Down Expand Up @@ -246,9 +247,7 @@ def forward(self, data):
###############################################################

# Compute 3x3 rotation matrix per edge
edge_rot_mat = self._init_edge_rot_mat(
data, graph.edge_index, graph.edge_distance_vec
)
edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
self.SO3_edge_rot = nn.ModuleList()
Expand All @@ -268,7 +267,6 @@ def forward(self, data):
device,
self.dtype,
)

offset_res = 0
offset = 0
# Initialize the l=0,m=0 coefficients for each resolution
Expand Down Expand Up @@ -360,63 +358,6 @@ def forward(self, data):

return outputs

# Initialize the edge rotation matrices
def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec):
edge_vec_0 = edge_distance_vec
edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1))

# Make sure the atoms are far enough apart
if torch.min(edge_vec_0_distance) < 0.0001:
logging.error(
f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}"
)
(minval, minidx) = torch.min(edge_vec_0_distance, 0)
logging.error(
f"Error edge_vec_0_distance: {minidx} {edge_index[0, minidx]} {edge_index[1, minidx]} {data.pos[edge_index[0, minidx]]} {data.pos[edge_index[1, minidx]]}"
)

norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1))

edge_vec_2 = torch.rand_like(edge_vec_0) - 0.5
edge_vec_2 = edge_vec_2 / (
torch.sqrt(torch.sum(edge_vec_2**2, dim=1)).view(-1, 1)
)
# Create two rotated copys of the random vectors in case the random vector is aligned with norm_x
# With two 90 degree rotated vectors, at least one should not be aligned with norm_x
edge_vec_2b = edge_vec_2.clone()
edge_vec_2b[:, 0] = -edge_vec_2[:, 1]
edge_vec_2b[:, 1] = edge_vec_2[:, 0]
edge_vec_2c = edge_vec_2.clone()
edge_vec_2c[:, 1] = -edge_vec_2[:, 2]
edge_vec_2c[:, 2] = edge_vec_2[:, 1]
vec_dot_b = torch.abs(torch.sum(edge_vec_2b * norm_x, dim=1)).view(-1, 1)
vec_dot_c = torch.abs(torch.sum(edge_vec_2c * norm_x, dim=1)).view(-1, 1)

vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1)
edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_b), edge_vec_2b, edge_vec_2)
vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1)
edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_c), edge_vec_2c, edge_vec_2)

vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1))
# Check the vectors aren't aligned
assert torch.max(vec_dot) < 0.99

norm_z = torch.cross(norm_x, edge_vec_2, dim=1)
norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1, keepdim=True)))
norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1)).view(-1, 1))
norm_y = torch.cross(norm_x, norm_z, dim=1)
norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True)))

# Construct the 3D rotation matrix
norm_x = norm_x.view(-1, 3, 1)
norm_y = -norm_y.view(-1, 3, 1)
norm_z = norm_z.view(-1, 3, 1)

edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2)
edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2)

return edge_rot_mat.detach()

@property
def num_params(self) -> int:
return sum(p.numel() for p in self.parameters())
Expand All @@ -440,9 +381,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
###############################################################

# Compute 3x3 rotation matrix per edge
edge_rot_mat = self._init_edge_rot_mat(
data, graph.edge_index, graph.edge_distance_vec
)
edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
self.SO3_edge_rot = nn.ModuleList()
Expand Down Expand Up @@ -537,7 +476,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:

@registry.register_model("escn_energy_head")
class eSCNEnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone, reduce = "sum"):
def __init__(self, backbone, reduce="sum"):
super().__init__()
backbone.energy_block = None
self.reduce = reduce
Expand All @@ -558,7 +497,9 @@ def forward(
elif self.reduce == "mean":
return {"energy": energy / data.natoms}
else:
raise ValueError(f"reduce can only be sum or mean, user provided: {self.reduce}")
raise ValueError(
f"reduce can only be sum or mean, user provided: {self.reduce}"
)


@registry.register_model("escn_force_head")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# ---
# name: TestEquiformerV2.test_ddp.1
Approx(
array([0.12408739], dtype=float32),
array([-0.00897979], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand All @@ -19,7 +19,7 @@
# ---
# name: TestEquiformerV2.test_ddp.3
Approx(
array([ 1.4928584e-03, -7.4167408e-05, 2.9909366e-03], dtype=float32),
array([-0.00893646, -0.00290753, -0.02622171], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand All @@ -31,7 +31,7 @@
# ---
# name: TestEquiformerV2.test_energy_force_shape.1
Approx(
array([0.12408739], dtype=float32),
array([-0.00897979], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand All @@ -44,7 +44,7 @@
# ---
# name: TestEquiformerV2.test_energy_force_shape.3
Approx(
array([ 1.4928584e-03, -7.4167408e-05, 2.9909366e-03], dtype=float32),
array([-0.00893646, -0.00290753, -0.02622171], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand All @@ -56,7 +56,7 @@
# ---
# name: TestEquiformerV2.test_gp.1
Approx(
array([0.12408739], dtype=float32),
array([-0.02495255], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand All @@ -69,7 +69,7 @@
# ---
# name: TestEquiformerV2.test_gp.3
Approx(
array([ 1.4928661e-03, -7.4134863e-05, 2.9909245e-03], dtype=float32),
array([ 0.00203054, -0.00042871, -0.00279118], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand Down