Skip to content

Commit

Permalink
changes based on comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikhil Shenoy committed Apr 6, 2024
1 parent 7493273 commit ed73e7d
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 38 deletions.
3 changes: 0 additions & 3 deletions openqdc/datasets/interaction/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,13 @@ def __getitem__(self, idx: int):
forces = self._convert_array(np.array(self.data["forces"][p_start:p_end], dtype=np.float32))

e0 = self._convert_array(np.array(self.__isolated_atom_energies__[..., z, c + shift].T, dtype=np.float32))
formation_energies = energies - e0.sum(axis=0)

bunch = Bunch(
positions=positions,
atomic_numbers=z,
charges=c,
e0=e0,
energies=energies,
formation_energies=formation_energies,
per_atom_formation_energies=formation_energies / len(z),
name=name,
subset=subset,
forces=forces,
Expand Down
7 changes: 1 addition & 6 deletions openqdc/datasets/interaction/des370k.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,22 +101,17 @@ def _read_raw_entries(cls) -> List[Dict]:
logger.info(f"Reading {cls._name} interaction data from {filepath}")
df = pd.read_csv(filepath)
data = []
for idx, row in tqdm(df.iterrows(), total=df.shape[0]):
for _, row in tqdm(df.iterrows(), total=df.shape[0]):
smiles0, smiles1 = row["smiles0"], row["smiles1"]
charge0, charge1 = row["charge0"], row["charge1"]
natoms0, natoms1 = row["natoms0"], row["natoms1"]
pos = np.array(list(map(float, row["xyz"].split()))).reshape(-1, 3)

elements = row["elements"].split()

atomic_nums = np.expand_dims(np.array([ATOM_TABLE.GetAtomicNumber(x) for x in elements]), axis=1)

charges = np.expand_dims(np.array([charge0] * natoms0 + [charge1] * natoms1), axis=1)

atomic_inputs = np.concatenate((atomic_nums, charges, pos), axis=-1, dtype=np.float32)

energies = np.array(row[cls.energy_target_names].values).astype(np.float32)[None, :]

name = np.array([smiles0 + "." + smiles1])

subsets = []
Expand Down
2 changes: 1 addition & 1 deletion openqdc/datasets/interaction/des5m.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@ class DES5M(DES370K):
__forces_unit__ = "kcal/mol/ang"

def read_raw_entries(self) -> List[Dict]:
    """Return the raw entries by delegating to the inherited reader.

    Uses ``super()`` rather than naming ``DES5M`` explicitly, so the
    lookup follows the class hierarchy and keeps working if the class
    is renamed or subclassed.

    Returns:
        Whatever the parent's ``_read_raw_entries`` produces (annotated
        as a list of dicts).
    """
    return super()._read_raw_entries()
8 changes: 1 addition & 7 deletions openqdc/datasets/interaction/dess66.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,24 +96,18 @@ def read_raw_entries(self) -> List[Dict]:
logger.info(f"Reading DESS66 interaction data from {self.filepath}")
df = pd.read_csv(self.filepath)
data = []
for idx, row in tqdm(df.iterrows(), total=df.shape[0]):
for _, row in tqdm(df.iterrows(), total=df.shape[0]):
smiles0, smiles1 = row["smiles0"], row["smiles1"]
charge0, charge1 = row["charge0"], row["charge1"]
natoms0, natoms1 = row["natoms0"], row["natoms1"]
pos = np.array(list(map(float, row["xyz"].split()))).reshape(-1, 3)

elements = row["elements"].split()

atomic_nums = np.expand_dims(np.array([ATOM_TABLE.GetAtomicNumber(x) for x in elements]), axis=1)

charges = np.expand_dims(np.array([charge0] * natoms0 + [charge1] * natoms1), axis=1)

atomic_inputs = np.concatenate((atomic_nums, charges, pos), axis=-1, dtype=np.float32)

energies = np.array(row[self.energy_target_names].values).astype(np.float32)[None, :]

name = np.array([smiles0 + "." + smiles1])

subset = row["system_name"]

item = dict(
Expand Down
6 changes: 0 additions & 6 deletions openqdc/datasets/interaction/dess66x8.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,11 @@ def read_raw_entries(self) -> List[Dict]:
pos = np.array(list(map(float, row["xyz"].split()))).reshape(-1, 3)

elements = row["elements"].split()

atomic_nums = np.expand_dims(np.array([ATOM_TABLE.GetAtomicNumber(x) for x in elements]), axis=1)

charges = np.expand_dims(np.array([charge0] * natoms0 + [charge1] * natoms1), axis=1)

atomic_inputs = np.concatenate((atomic_nums, charges, pos), axis=-1, dtype=np.float32)

energies = np.array(row[self.energy_target_names].values).astype(np.float32)[None, :]

name = np.array([smiles0 + "." + smiles1])

subset = row["system_name"]

item = dict(
Expand Down
5 changes: 1 addition & 4 deletions openqdc/datasets/interaction/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ class DummyInteraction(BaseInteractionDataset):

__name__ = "dummy_interaction"
__energy_methods__ = [InteractionMethod.SAPT0_AUG_CC_PVDDZ, InteractionMethod.CCSD_T_CC_PVDZ]
__force_mask__ = [False, True]
__force_mask__ = [False, False]
__energy_unit__ = "kcal/mol"
__distance_unit__ = "ang"
__forces_unit__ = "kcal/mol/ang"

energy_target_names = [f"energy{i}" for i in range(len(__energy_methods__))]

force_target_names = [f"forces{i}" for i in range(len(__force_mask__))]
__isolated_atom_energies__ = []
__average_n_atoms__ = None

Expand Down Expand Up @@ -48,7 +47,6 @@ def setup_dummy(self):
name = [f"dummy_{i}" for i in range(len(self))]
subset = ["dummy" for i in range(len(self))]
energies = np.random.rand(len(self), len(self.energy_methods))
forces = np.concatenate([np.random.randn(size, 3, len(self.force_methods)) * 100 for size in n_atoms])
self.data = dict(
n_atoms=n_atoms,
position_idx_range=position_idx_range,
Expand All @@ -57,7 +55,6 @@ def setup_dummy(self):
subset=subset,
energies=energies,
n_atoms_first=n_atoms_first,
forces=forces,
)
self.__average_nb_atoms__ = self.data["n_atoms"].mean()

Expand Down
3 changes: 2 additions & 1 deletion openqdc/datasets/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def to_dict(self):

def transform(self, func):
    """Apply ``func`` in place to every non-``None`` value of this object.

    Iterates over the key/value pairs returned by ``self.to_dict()`` and
    rebinds each attribute to ``func(value)``.

    Args:
        func: Callable taking one value and returning its replacement.

    Returns:
        None. The object is modified in place.
    """
    for key, value in self.to_dict().items():
        # None values are left untouched so func never receives None.
        if value is not None:
            setattr(self, key, func(value))


@dataclass
Expand Down
31 changes: 21 additions & 10 deletions tests/test_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@
from openqdc.utils.io import get_local_cache
from openqdc.utils.package_utils import has_package


# start by removing any cached data
cache_dir = get_local_cache()
os.system(f"rm -rf {cache_dir}/dummy")
os.system(f"rm -rf {cache_dir}/dummy_interaction")
@pytest.fixture(autouse=True)
def clean_before_run():
    """Remove cached dummy datasets before each test runs.

    Deletes the ``dummy`` and ``dummy_interaction`` directories under the
    local cache directory. Missing directories are ignored.
    """
    import shutil

    cache_dir = str(get_local_cache())
    for dataset_dir in ("dummy", "dummy_interaction"):
        # shutil.rmtree avoids spawning a shell: it is safe for cache paths
        # with spaces or shell metacharacters and works on every platform.
        shutil.rmtree(os.path.join(cache_dir, dataset_dir), ignore_errors=True)
    yield


if has_package("torch"):
Expand Down Expand Up @@ -62,12 +67,15 @@ def test_dummy_array_format(interaction_ds, format):
"energies",
"forces",
"e0",
"formation_energies",
"per_atom_formation_energies",
]
if not interaction_ds:
# additional keys returned from the potential dataset
keys.extend(["formation_energies", "per_atom_formation_energies"])

data = ds[0]
for key in keys:
if data[key] is None:
continue
assert isinstance(data[key], format_to_type[format])


Expand Down Expand Up @@ -125,11 +133,12 @@ def test_force_statistics_shapes(ds, request):
keys = ["mean", "std", "component_mean", "component_std", "component_rms"]
assert all(k in forces_stats for k in keys)

assert forces_stats["mean"].shape == (1, num_force_methods)
assert forces_stats["std"].shape == (1, num_force_methods)
assert forces_stats["component_mean"].shape == (3, num_force_methods)
assert forces_stats["component_std"].shape == (3, num_force_methods)
assert forces_stats["component_rms"].shape == (3, num_force_methods)
if len(ds.force_methods) > 0:
assert forces_stats["mean"].shape == (1, num_force_methods)
assert forces_stats["std"].shape == (1, num_force_methods)
assert forces_stats["component_mean"].shape == (3, num_force_methods)
assert forces_stats["component_std"].shape == (3, num_force_methods)
assert forces_stats["component_rms"].shape == (3, num_force_methods)


@pytest.mark.parametrize("interaction_ds", [False, True])
Expand All @@ -143,4 +152,6 @@ def test_stats_array_format(interaction_ds, format):

for key in stats.keys():
for k, v in stats[key].items():
if v is None:
continue
assert isinstance(v, format_to_type[format])

0 comments on commit ed73e7d

Please sign in to comment.