diff --git a/src/openqdc/datasets/base.py b/src/openqdc/datasets/base.py index 91a8482..5ad18d4 100644 --- a/src/openqdc/datasets/base.py +++ b/src/openqdc/datasets/base.py @@ -205,6 +205,8 @@ def _remove_outliers( f"{avg_fn} is not a valid option, should be one of {list(BaseDataset.avg_options.keys())}" ) logger.info(f"Removing outliers outside {avg_fn} +/- {num_stds} stds") + formation_E /= self.data["n_atoms"] # convert to avg formation energy / atom + formation_E = np.squeeze(formation_E.T) # remove extra array dimension and transpose fn = BaseDataset.avg_options[avg_fn] mid = fn(formation_E) mask = np.logical_or(formation_E < mid - num_stds * formation_E.std(), formation_E > mid + num_stds * formation_E.std()) @@ -232,7 +234,7 @@ def _precompute_E(self): # remove outliers if requested in __init__ if self.remove_outliers: - E = self._remove_outliers(np.squeeze(E.T), + E = self._remove_outliers(E, avg_fn=self.avg_fn, num_stds=self.num_stds)