From 94e4a7f9c3fd932a8adfb73e7c47158b63f44fe9 Mon Sep 17 00:00:00 2001 From: rayg1234 <7001989+rayg1234@users.noreply.github.com> Date: Thu, 22 Aug 2024 10:52:34 -0700 Subject: [PATCH] add option to apply mean to energy (#818) * add option to apply mean to energy * lint * change name to reduce * update --- .../core/models/equiformer_v2/equiformer_v2.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 61b62be16..2851acbc2 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -610,9 +610,9 @@ def no_weight_decay(self) -> set: @registry.register_model("equiformer_v2_energy_head") class EquiformerV2EnergyHead(nn.Module, HeadInterface): - def __init__(self, backbone): + def __init__(self, backbone, reduce: str="sum"): super().__init__() - + self.reduce = reduce self.avg_num_nodes = backbone.avg_num_nodes self.energy_block = FeedForwardNetwork( backbone.sphere_channels, @@ -638,8 +638,15 @@ def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]): device=node_energy.device, dtype=node_energy.dtype, ) + energy.index_add_(0, data.batch, node_energy.view(-1)) - return {"energy": energy / self.avg_num_nodes} + if self.reduce == "sum": + return {"energy": energy / self.avg_num_nodes} + elif self.reduce == "mean": + return {"energy": energy / data.natoms} + else: + raise ValueError(f"reduce can only be sum or mean, user provided: {self.reduce}") + @registry.register_model("equiformer_v2_force_head")