move escn to deprecated #851

Open · wants to merge 6 commits into base: main
2 changes: 1 addition & 1 deletion src/fairchem/core/models/escn/__init__.py
@@ -1,5 +1,5 @@
from __future__ import annotations

-from .escn import eSCN
+from .escn_deprecated import eSCN

__all__ = ["eSCN"]
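
Because the package `__init__` now re-exports `eSCN` from `escn_deprecated` and leaves `__all__` unchanged, existing imports of the class by its old path should keep working. A minimal sanity check, assuming only the package layout shown in this diff (the printed module name is an expectation, not something stated in the PR):

```python
# Hedged sketch: verifies the re-export only; assumes fairchem-core is installed.
from fairchem.core.models.escn import eSCN

# After this PR the class object should come from the deprecated module.
print(eSCN.__module__)  # expected to end with "escn_deprecated"
```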
163 changes: 14 additions & 149 deletions src/fairchem/core/models/escn/escn.py
@@ -9,7 +9,6 @@

import contextlib
import logging
import time
import typing

import torch
@@ -39,8 +38,8 @@
from e3nn import o3


@registry.register_model("escn")
class eSCN(nn.Module, GraphModelMixin):
@registry.register_model("escn_backbone")
class eSCNBackbone(nn.Module, BackboneInterface, GraphModelMixin):
"""Equivariant Spherical Channel Network
Paper: Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs

@@ -228,17 +227,14 @@ def __init__(
self.sphharm_weights = nn.ParameterList(sphharm_weights)

@conditional_grad(torch.enable_grad())
-def forward(self, data):
+def forward(self, data: Batch) -> dict[str, torch.Tensor]:
device = data.pos.device
self.batch_size = len(data.natoms)
self.dtype = data.pos.dtype

start_time = time.time()
atomic_numbers = data.atomic_numbers.long()
-assert (
-atomic_numbers.max().item() < self.max_num_elements
-), "Atomic number exceeds that given in model config"
num_atoms = len(atomic_numbers)

Collaborator (review comment): add this back?
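
Restoring the check the review comment asks about would be a small re-insertion near the top of the new `forward`; a sketch written as a hypothetical standalone helper so it is self-contained (the check itself and its message come from the removed lines above, the helper name is illustrative):

```python
import torch

def validate_atomic_numbers(atomic_numbers: torch.Tensor, max_num_elements: int) -> None:
    # Same validation as the removed assert, per the review comment above.
    assert (
        atomic_numbers.max().item() < max_num_elements
    ), "Atomic number exceeds that given in model config"
```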

graph = self.generate_graph(data)

###############################################################
@@ -333,32 +329,12 @@ def forward(self, data):

x_pt = x_pt.view(-1, self.sphere_channels_all)

###############################################################
# Energy estimation
###############################################################
node_energy = self.energy_block(x_pt)
energy = torch.zeros(len(data.natoms), device=device)
energy.index_add_(0, data.batch, node_energy.view(-1))
# Scale energy to help balance numerical precision w.r.t. forces
energy = energy * 0.001

outputs = {"energy": energy}
###############################################################
# Force estimation
###############################################################
if self.regress_forces:
forces = self.force_block(x_pt, self.sphere_points)
outputs["forces"] = forces

if self.show_timing_info is True:
torch.cuda.synchronize()
logging.info(
f"{self.counter} Time: {time.time() - start_time}\tMemory: {len(data.pos)}\t{torch.cuda.max_memory_allocated() / 1000000}"
)

self.counter = self.counter + 1

return outputs
return {
"sphere_values": x_pt,
"sphere_points": self.sphere_points,
"node_embedding": x,
"graph": graph,
}
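
With the energy and force math removed from `forward`, the backbone now returns only the dictionary above and leaves the reductions to the heads. A rough sketch of how an energy head might consume it, written as a free function for illustration (the key name, the `index_add_` pattern, and the 0.001 scaling come from the code removed in this hunk; everything else is assumed):

```python
import torch

def example_energy_head(data, emb: dict, energy_block: torch.nn.Module) -> dict:
    # emb is the dict returned by eSCNBackbone.forward above.
    node_energy = energy_block(emb["sphere_values"])        # per-node energies
    energy = torch.zeros(len(data.natoms), device=node_energy.device)
    energy.index_add_(0, data.batch, node_energy.view(-1))  # sum atoms per structure
    return {"energy": energy * 0.001}                       # scaling kept from the removed code
```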

# Initialize the edge rotation matrices
def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec):
@@ -422,122 +398,9 @@ def num_params(self) -> int:
return sum(p.numel() for p in self.parameters())


@registry.register_model("escn_backbone")
class eSCNBackbone(eSCN, BackboneInterface):
@conditional_grad(torch.enable_grad())
def forward(self, data: Batch) -> dict[str, torch.Tensor]:
device = data.pos.device
self.batch_size = len(data.natoms)
self.dtype = data.pos.dtype

atomic_numbers = data.atomic_numbers.long()
num_atoms = len(atomic_numbers)

graph = self.generate_graph(data)

###############################################################
# Initialize data structures
###############################################################

# Compute 3x3 rotation matrix per edge
edge_rot_mat = self._init_edge_rot_mat(
data, graph.edge_index, graph.edge_distance_vec
)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
self.SO3_edge_rot = nn.ModuleList()
for i in range(self.num_resolutions):
self.SO3_edge_rot.append(SO3_Rotation(edge_rot_mat, self.lmax_list[i]))

###############################################################
# Initialize node embeddings
###############################################################

# Init per node representations using an atomic number based embedding
offset = 0
x = SO3_Embedding(
num_atoms,
self.lmax_list,
self.sphere_channels,
device,
self.dtype,
)

offset_res = 0
offset = 0
# Initialize the l=0,m=0 coefficients for each resolution
for i in range(self.num_resolutions):
x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[
:, offset : offset + self.sphere_channels
]
offset = offset + self.sphere_channels
offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2)

# This can be expensive to compute (not implemented efficiently), so only do it once and pass it along to each layer
mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list, device)

###############################################################
# Update spherical node embeddings
###############################################################

for i in range(self.num_layers):
if i > 0:
x_message = self.layer_blocks[i](
x,
atomic_numbers,
graph.edge_distance,
graph.edge_index,
self.SO3_edge_rot,
mappingReduced,
)

# Residual layer for all layers past the first
x.embedding = x.embedding + x_message.embedding

else:
# No residual for the first layer
x = self.layer_blocks[i](
x,
atomic_numbers,
graph.edge_distance,
graph.edge_index,
self.SO3_edge_rot,
mappingReduced,
)

# Sample the spherical channels (node embeddings) at evenly distributed points on the sphere.
# These values are fed into the output blocks.
x_pt = torch.tensor([], device=device)
offset = 0
# Compute the embedding values at every sampled point on the sphere
for i in range(self.num_resolutions):
num_coefficients = int((x.lmax_list[i] + 1) ** 2)
x_pt = torch.cat(
[
x_pt,
torch.einsum(
"abc, pb->apc",
x.embedding[:, offset : offset + num_coefficients],
self.sphharm_weights[i],
).contiguous(),
],
dim=2,
)
offset = offset + num_coefficients

x_pt = x_pt.view(-1, self.sphere_channels_all)

return {
"sphere_values": x_pt,
"sphere_points": self.sphere_points,
"node_embedding": x,
"graph": graph,
}
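
As an aside, the `einsum("abc, pb->apc", ...)` used for the sphere sampling above maps per-coefficient node embeddings onto the sampled sphere points. A small shape check with made-up sizes (only the subscript pattern comes from the diff; the names and numbers are illustrative):

```python
import torch

num_atoms, num_coefficients, channels, num_points = 5, 49, 16, 128  # illustrative sizes
embedding = torch.randn(num_atoms, num_coefficients, channels)       # "abc"
sphharm_weights = torch.randn(num_points, num_coefficients)          # "pb"

x_pt = torch.einsum("abc, pb->apc", embedding, sphharm_weights)
print(x_pt.shape)  # torch.Size([5, 128, 16]) -> (atoms, sphere points, channels)
```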


@registry.register_model("escn_energy_head")
class eSCNEnergyHead(nn.Module, HeadInterface):
-def __init__(self, backbone, reduce = "sum"):
+def __init__(self, backbone, reduce="sum"):
super().__init__()
backbone.energy_block = None
self.reduce = reduce
@@ -558,7 +421,9 @@ def forward(
elif self.reduce == "mean":
return {"energy": energy / data.natoms}
else:
raise ValueError(f"reduce can only be sum or mean, user provided: {self.reduce}")
raise ValueError(
f"reduce can only be sum or mean, user provided: {self.reduce}"
)
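
For clarity, a tiny numerical illustration of the two supported `reduce` modes (made-up values; it mirrors the sum over `data.batch` and the division by `data.natoms` shown above):

```python
import torch

# Two structures in one batch: 2 atoms and 3 atoms (illustrative values only).
natoms = torch.tensor([2.0, 3.0])
batch = torch.tensor([0, 0, 1, 1, 1])
node_energy = torch.tensor([1.0, 3.0, 2.0, 2.0, 2.0])

energy = torch.zeros(2).index_add_(0, batch, node_energy)
print(energy)           # reduce="sum"  -> tensor([4., 6.])
print(energy / natoms)  # reduce="mean" -> tensor([2., 2.])
```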


@registry.register_model("escn_force_head")
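
For reference, the classes registered in this diff (`escn_backbone`, `escn_energy_head`, `escn_force_head`) can be looked up by name through the fairchem registry. A minimal lookup sketch, assuming the registry import path used elsewhere in fairchem-core and that `get_model_class` is the accessor matching `register_model`:

```python
# Hedged sketch: only the registered names come from this diff.
from fairchem.core.common.registry import registry

BackboneCls = registry.get_model_class("escn_backbone")
EnergyHeadCls = registry.get_model_class("escn_energy_head")
ForceHeadCls = registry.get_model_class("escn_force_head")
print(BackboneCls.__name__, EnergyHeadCls.__name__, ForceHeadCls.__name__)
```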