Skip to content

Commit 611b3f1

Browse files
ds bug fix
1 parent f75bd4c commit 611b3f1

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

apax/train/run.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,6 @@ def initialize_dataset(config, raw_ds, calc_stats: bool = True):
8585
energy_unit=config.data.energy_unit,
8686
)
8787

88-
dataset = TFPipeline(
89-
inputs,
90-
labels,
91-
config.n_epochs,
92-
config.data.batch_size,
93-
buffer_size=config.data.shuffle_buffer_size,
94-
)
95-
9688
if calc_stats:
9789
ds_stats = compute_scale_shift_parameters(
9890
inputs,
@@ -103,6 +95,15 @@ def initialize_dataset(config, raw_ds, calc_stats: bool = True):
10395
config.data.scale_options,
10496
)
10597

98+
dataset = TFPipeline(
99+
inputs,
100+
labels,
101+
config.n_epochs,
102+
config.data.batch_size,
103+
buffer_size=config.data.shuffle_buffer_size,
104+
)
105+
106+
if calc_stats:
106107
return dataset, ds_stats
107108
else:
108109
return dataset

apax/utils/convert.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ def atoms_to_arrays(
109109

110110
inputs["ragged"]["numbers"].append(atoms.numbers)
111111
inputs["fixed"]["n_atoms"].append(len(atoms))
112-
113112
for key, val in atoms.calc.results.items():
114113
if key == "forces":
115114
labels["ragged"][key].append(

tests/unit_tests/data/test_statistics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_energy_per_element():
2424
atoms.calc = SinglePointCalculator(atoms, energy=energy)
2525

2626
labels = {
27-
"ragged": {
27+
"fixed": {
2828
"energy": [atoms.get_potential_energy() for atoms in atoms_list],
2929
}
3030
}

0 commit comments

Comments
 (0)