Skip to content

Commit

Permalink
add option to apply mean to energy (#818)
Browse files Browse the repository at this point in the history
* add option to apply mean to energy

* lint

* change name to reduce

* update
  • Loading branch information
rayg1234 committed Aug 22, 2024
1 parent c2b0c30 commit 94e4a7f
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,9 +610,9 @@ def no_weight_decay(self) -> set:

@registry.register_model("equiformer_v2_energy_head")
class EquiformerV2EnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone):
def __init__(self, backbone, reduce: str="sum"):
super().__init__()

self.reduce = reduce
self.avg_num_nodes = backbone.avg_num_nodes
self.energy_block = FeedForwardNetwork(
backbone.sphere_channels,
Expand All @@ -638,8 +638,15 @@ def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
device=node_energy.device,
dtype=node_energy.dtype,
)

energy.index_add_(0, data.batch, node_energy.view(-1))
return {"energy": energy / self.avg_num_nodes}
if self.reduce == "sum":
return {"energy": energy / self.avg_num_nodes}
elif self.reduce == "mean":
return {"energy": energy / data.natoms}
else:
raise ValueError(f"reduce can only be sum or mean, user provided: {self.reduce}")



@registry.register_model("equiformer_v2_force_head")
Expand Down

0 comments on commit 94e4a7f

Please sign in to comment.