diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 61b62be16..2851acbc2 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -610,9 +610,9 @@ def no_weight_decay(self) -> set: @registry.register_model("equiformer_v2_energy_head") class EquiformerV2EnergyHead(nn.Module, HeadInterface): - def __init__(self, backbone): + def __init__(self, backbone, reduce: str="sum"): super().__init__() - + self.reduce = reduce self.avg_num_nodes = backbone.avg_num_nodes self.energy_block = FeedForwardNetwork( backbone.sphere_channels, @@ -638,8 +638,15 @@ def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]): device=node_energy.device, dtype=node_energy.dtype, ) + energy.index_add_(0, data.batch, node_energy.view(-1)) - return {"energy": energy / self.avg_num_nodes} + if self.reduce == "sum": + return {"energy": energy / self.avg_num_nodes} + elif self.reduce == "mean": + return {"energy": energy / data.natoms} + else: + raise ValueError(f"reduce can only be sum or mean, user provided: {self.reduce}") + @registry.register_model("equiformer_v2_force_head")