Skip to content

Commit

Permalink
Merge pull request #105 from sparks-baird/rgb-kwarg-tests-fix
Browse files Browse the repository at this point in the history
Add tests for rgb_scaling=False
  • Loading branch information
sgbaird authored Jun 17, 2022
2 parents 1776dae + b33a42f commit 42c6a46
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/xtal2png/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def png2xtal(
def structures_to_arrays(
self,
structures: Sequence[Structure],
rgb_output=True,
rgb_scaling=True,
):
"""Convert pymatgen Structure to scaled 3D array of crystallographic info.
Expand Down Expand Up @@ -587,7 +587,7 @@ def structures_to_arrays(
frac_coords = np.stack(frac_coords_tmp)
distance_matrix = np.stack(distance_matrix_tmp)

if rgb_output:
if rgb_scaling:
# REVIEW: consider using modified pettifor scale instead of atomic numbers
# REVIEW: consider using feature_range=atom_range or 2*atom_range
# REVIEW: since it introduces a sort of non-linearity b.c. of rounding
Expand Down Expand Up @@ -636,7 +636,9 @@ def structures_to_arrays(
volume, feature_range=feature_range, data_range=self.volume_range
)
space_group_scaled = element_wise_scaler(
space_group, data_range=self.space_group_range
space_group,
feature_range=feature_range,
data_range=self.space_group_range,
)
distance_scaled = element_wise_scaler(
distance_matrix,
Expand Down Expand Up @@ -821,7 +823,7 @@ def arrays_to_structures(
data: np.ndarray,
id_data: Optional[np.ndarray] = None,
id_mapper: Optional[dict] = None,
rgb_output: bool = True,
rgb_scaling: bool = True,
):
"""Convert scaled crystal (xtal) arrays to pymatgen Structures.
Expand Down Expand Up @@ -867,7 +869,7 @@ def arrays_to_structures(
int
)

if rgb_output:
if rgb_scaling:
atomic_numbers = rgb_unscaler(atom_scaled, data_range=self.atom_range)
frac_coords = rgb_unscaler(frac_scaled, data_range=self.frac_range)
latt_a = rgb_unscaler(a_scaled, data_range=self.a_range)
Expand Down
28 changes: 28 additions & 0 deletions tests/xtal2png_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from warnings import warn

import numpy as np
import plotly.express as px
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher
Expand Down Expand Up @@ -125,6 +126,21 @@ def test_structures_to_arrays_single():
return data


def test_structures_to_arrays_zero_one():
xc = XtalConverter(relax_on_decode=False)
data, _, _ = xc.structures_to_arrays(example_structures, rgb_scaling=False)

if np.min(data) < 0.0:
raise ValueError(
f"minimum is less than 0 when rgb_output=False: {np.min(data)}"
)
if np.max(data) > 1.0:
raise ValueError(
f"maximum is greater than 1 when rgb_output=False: {np.max(data)}"
)
return data


def test_arrays_to_structures():
xc = XtalConverter(relax_on_decode=False)
data, id_data, id_mapper = xc.structures_to_arrays(example_structures)
Expand All @@ -133,6 +149,16 @@ def test_arrays_to_structures():
return structures


def test_arrays_to_structures_zero_one():
xc = XtalConverter(relax_on_decode=False)
data, id_data, id_mapper = xc.structures_to_arrays(
example_structures, rgb_scaling=False
)
structures = xc.arrays_to_structures(data, id_data, id_mapper, rgb_scaling=False)
assert_structures_approximate_match(example_structures, structures)
return structures


def test_arrays_to_structures_single():
xc = XtalConverter(relax_on_decode=False)
data, id_data, id_mapper = xc.structures_to_arrays([example_structures[0]])
Expand Down Expand Up @@ -295,6 +321,8 @@ def test_plot_and_save():


if __name__ == "__main__":
test_structures_to_arrays_zero_one()
test_arrays_to_structures_zero_one()
test_relax_on_decode()
test_primitive_decoding()
test_primitive_encoding()
Expand Down

0 comments on commit 42c6a46

Please sign in to comment.