Skip to content

Commit

Permalink
Delete pickle identifier and species in torch dataset since the whole…
Browse files Browse the repository at this point in the history
… configuraion is pickled
  • Loading branch information
Mingjian Wen committed Aug 18, 2019
1 parent d2724ed commit b3083c1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 17 deletions.
14 changes: 1 addition & 13 deletions kliff/descriptors/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import kliff
from .. import parallel
from ..log import log_entry
from ..atomic_data import atomic_number

logger = kliff.logger.get_logger(__name__)

Expand Down Expand Up @@ -235,10 +234,6 @@ def dump_fingerprints(
dzetadr_s = dzetadr_s / stdev_3d

# pickling data
identifier = conf.get_identifier()
species = conf.get_species()
species = np.asarray([atomic_number[i] for i in species], np.intc)
weight = np.asarray(conf.get_weight(), self.dtype)
zeta = np.asarray(zeta, self.dtype)
energy = np.asarray(conf.get_energy(), self.dtype)
if fit_forces:
Expand All @@ -249,14 +244,7 @@ def dump_fingerprints(
stress = np.asarray(conf.get_stress(), self.dtype)
volume = np.asarray(conf.get_volume(), self.dtype)

example = {
'configuration': conf,
'identifier': identifier,
'species': species,
'weight': weight,
'zeta': zeta,
'energy': energy,
}
example = {'configuration': conf, 'zeta': zeta, 'energy': energy}
if fit_forces:
example['dzetadr_forces'] = dzetadr_f
example['forces'] = forces
Expand Down
8 changes: 4 additions & 4 deletions kliff/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,10 +774,10 @@ def get_loss_single_config(self, sample, pred_energy, pred_forces, pred_stress):
pred = pred_stress.reshape(-1)
ref = ref_stress.reshape(-1)

identifier = sample['identifier']
species = sample['species']
weight = sample['weight']
natoms = len(species)
conf = sample['configuration']
identifier = conf.get_identifier()
natoms = conf.get_number_of_atoms()
weight = conf.get_weight()

residual = self.residual_fn(
identifier, natoms, weight, pred, ref, self.residual_data
Expand Down

0 comments on commit b3083c1

Please sign in to comment.