Skip to content

Commit af9f23e

Browse files
committed
add tests for new lattice param env
1 parent ea3b03c commit af9f23e

File tree

1 file changed

+310
-0
lines changed

1 file changed

+310
-0
lines changed
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
import common
2+
import pytest
3+
import torch
4+
5+
from gflownet.envs.crystals.clattice_parameters import (
6+
CUBIC,
7+
HEXAGONAL,
8+
MONOCLINIC,
9+
ORTHORHOMBIC,
10+
PARAMETER_NAMES,
11+
RHOMBOHEDRAL,
12+
TETRAGONAL,
13+
TRICLINIC,
14+
CLatticeParametersSingleDimIncrement,
15+
)
16+
from gflownet.envs.crystals.lattice_parameters import LATTICE_SYSTEMS
17+
from gflownet.utils.common import tfloat
18+
19+
N_REPETITIONS = 100
20+
21+
22+
def values_are_source_or_distinct(values, source_value=-1):
23+
values_not_source = [v for v in values if v != source_value]
24+
return len({*values_not_source}) == len(values_not_source)
25+
26+
27+
@pytest.fixture()
28+
def env(lattice_system):
29+
return CLatticeParametersSingleDimIncrement(
30+
lattice_system=lattice_system,
31+
min_length=1.0,
32+
max_length=5.0,
33+
min_angle=30.0,
34+
max_angle=150.0,
35+
)
36+
37+
38+
@pytest.mark.parametrize("lattice_system", LATTICE_SYSTEMS)
39+
def test__environment__initializes_properly(env, lattice_system):
40+
pass
41+
42+
43+
@pytest.mark.parametrize(
44+
"lattice_system, expected_params",
45+
[
46+
(CUBIC, [None, None, None, 90, 90, 90]),
47+
(HEXAGONAL, [None, None, None, 90, 90, 120]),
48+
(MONOCLINIC, [None, None, None, 90, None, 90]),
49+
(ORTHORHOMBIC, [None, None, None, 90, 90, 90]),
50+
(RHOMBOHEDRAL, [None, None, None, None, None, None]),
51+
(TETRAGONAL, [None, None, None, 90, 90, 90]),
52+
(TRICLINIC, [None, None, None, None, None, None]),
53+
],
54+
)
55+
def test__environment__has_expected_fixed_parameters(
56+
env, lattice_system, expected_params
57+
):
58+
for expected_value, param_name in zip(expected_params, PARAMETER_NAMES):
59+
if expected_value is not None:
60+
assert getattr(env, param_name) == expected_value
61+
62+
63+
@pytest.mark.parametrize(
64+
"lattice_system",
65+
[CUBIC],
66+
)
67+
@pytest.mark.repeat(N_REPETITIONS)
68+
def test__cubic__constraints_remain_after_random_actions(env, lattice_system):
69+
env = env.reset()
70+
while not env.done:
71+
(a, b, c), (alpha, beta, gamma) = env._unpack_lengths_angles()
72+
assert len({a, b, c}) == 1
73+
assert len({alpha, beta, gamma, 90.0}) == 1
74+
env.step_random()
75+
76+
77+
@pytest.mark.parametrize(
78+
"lattice_system",
79+
[HEXAGONAL],
80+
)
81+
@pytest.mark.repeat(N_REPETITIONS)
82+
def test__hexagonal__constraints_remain_after_random_actions(env, lattice_system):
83+
env = env.reset()
84+
while not env.done:
85+
env.step_random()
86+
(a, b, c), (alpha, beta, gamma) = env._unpack_lengths_angles()
87+
assert a == b
88+
if c != -1:
89+
assert c != a
90+
assert len({alpha, beta, 90.0}) == 1
91+
assert gamma == 120.0
92+
93+
94+
@pytest.mark.parametrize(
95+
"lattice_system",
96+
[MONOCLINIC],
97+
)
98+
@pytest.mark.repeat(N_REPETITIONS)
99+
def test__monoclinic__constraints_remain_after_random_actions(env, lattice_system):
100+
env = env.reset()
101+
while not env.done:
102+
env.step_random()
103+
(a, b, c), (alpha, beta, gamma) = env._unpack_lengths_angles()
104+
values_are_source_or_distinct((a, b, c))
105+
assert len({alpha, gamma, 90.0}) == 1
106+
assert beta != 90.0
107+
108+
109+
@pytest.mark.parametrize(
110+
"lattice_system",
111+
[ORTHORHOMBIC],
112+
)
113+
@pytest.mark.repeat(N_REPETITIONS)
114+
def test__orthorhombic__constraints_remain_after_random_actions(env, lattice_system):
115+
env = env.reset()
116+
while not env.done:
117+
env.step_random()
118+
(a, b, c), (alpha, beta, gamma) = env._unpack_lengths_angles()
119+
values_are_source_or_distinct((a, b, c))
120+
assert len({alpha, beta, gamma, 90.0}) == 1
121+
122+
123+
@pytest.mark.parametrize(
124+
"lattice_system",
125+
[RHOMBOHEDRAL],
126+
)
127+
@pytest.mark.repeat(N_REPETITIONS)
128+
def test__rhombohedral__constraints_remain_after_random_actions(env, lattice_system):
129+
env = env.reset()
130+
while not env.done:
131+
env.step_random()
132+
(a, b, c), (alpha, beta, gamma) = env._unpack_lengths_angles()
133+
assert len({a, b, c}) == 1
134+
assert len({alpha, beta, gamma}) == 1
135+
assert len({alpha, beta, gamma, 90.0}) == 2
136+
137+
138+
@pytest.mark.parametrize(
139+
"lattice_system",
140+
[TETRAGONAL],
141+
)
142+
@pytest.mark.repeat(N_REPETITIONS)
143+
def test__tetragonal__constraints_remain_after_random_actions(env, lattice_system):
144+
env = env.reset()
145+
while not env.done:
146+
env.step_random()
147+
(a, b, c), (alpha, beta, gamma) = env._unpack_lengths_angles()
148+
values_are_source_or_distinct((a, c))
149+
values_are_source_or_distinct((b, c))
150+
assert len({alpha, beta, gamma, 90.0}) == 1
151+
152+
153+
@pytest.mark.parametrize(
154+
"lattice_system",
155+
[TRICLINIC],
156+
)
157+
@pytest.mark.repeat(N_REPETITIONS)
158+
def test__triclinic__constraints_remain_after_random_actions(env, lattice_system):
159+
env = env.reset()
160+
while not env.done:
161+
env.step_random()
162+
(a, b, c), (alpha, beta, gamma) = env._unpack_lengths_angles()
163+
values_are_source_or_distinct((a, b, c))
164+
values_are_source_or_distinct((alpha, beta, gamma, 90.0))
165+
166+
167+
@pytest.mark.parametrize(
168+
"lattice_system, states, states_proxy_expected",
169+
[
170+
(
171+
TRICLINIC,
172+
[
173+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
174+
[0.0, 0.2, 0.5, 0.0, 0.5, 1.0],
175+
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
176+
],
177+
[
178+
[1.0, 1.0, 1.0, 30.0, 30.0, 30.0],
179+
[1.0, 1.8, 3.0, 30.0, 90.0, 150.0],
180+
[5.0, 5.0, 5.0, 150.0, 150.0, 150.0],
181+
],
182+
),
183+
(
184+
CUBIC,
185+
[
186+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
187+
[0.25, 0.5, 0.75, 0.25, 0.5, 0.75],
188+
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
189+
],
190+
[
191+
[1.0, 1.0, 1.0, 30.0, 30.0, 30.0],
192+
[2.0, 3.0, 4.0, 60.0, 90.0, 120.0],
193+
[5.0, 5.0, 5.0, 150.0, 150.0, 150.0],
194+
],
195+
),
196+
],
197+
)
198+
def test__statetorch2proxy__returns_expected(
199+
env, lattice_system, states, states_proxy_expected
200+
):
201+
"""
202+
Various lattice systems are tried because the conversion should be independent of
203+
the lattice system, since the states are expected to satisfy the constraints.
204+
"""
205+
# Get policy states from the batch of states converted into each subenv
206+
# Get policy states from env.statetorch2policy
207+
states_torch = tfloat(states, float_type=env.float, device=env.device)
208+
states_proxy_expected_torch = tfloat(
209+
states_proxy_expected, float_type=env.float, device=env.device
210+
)
211+
states_proxy = env.statetorch2proxy(states_torch)
212+
assert torch.all(torch.eq(states_proxy, states_proxy_expected_torch))
213+
states_proxy = env.statebatch2proxy(states_torch)
214+
assert torch.all(torch.eq(states_proxy, states_proxy_expected_torch))
215+
216+
217+
@pytest.mark.parametrize(
218+
"lattice_system, states, states_policy_expected",
219+
[
220+
(
221+
TRICLINIC,
222+
[
223+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
224+
[0.0, 0.2, 0.5, 0.0, 0.5, 1.0],
225+
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
226+
],
227+
[
228+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
229+
[0.0, 0.2, 0.5, 0.0, 0.5, 1.0],
230+
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
231+
],
232+
),
233+
(
234+
CUBIC,
235+
[
236+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
237+
[0.25, 0.5, 0.75, 0.25, 0.5, 0.75],
238+
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
239+
],
240+
[
241+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
242+
[0.25, 0.5, 0.75, 0.25, 0.5, 0.75],
243+
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
244+
],
245+
),
246+
],
247+
)
248+
def test__statetorch2policy__returns_expected(
249+
env, lattice_system, states, states_policy_expected
250+
):
251+
"""
252+
Various lattice systems are tried because the conversion should be independent of
253+
the lattice system, since the states are expected to satisfy the constraints.
254+
"""
255+
# Get policy states from the batch of states converted into each subenv
256+
# Get policy states from env.statetorch2policy
257+
states_torch = tfloat(states, float_type=env.float, device=env.device)
258+
states_policy_expected_torch = tfloat(
259+
states_policy_expected, float_type=env.float, device=env.device
260+
)
261+
states_policy = env.statetorch2policy(states_torch)
262+
assert torch.all(torch.eq(states_policy, states_policy_expected_torch))
263+
states_policy = env.statebatch2policy(states_torch)
264+
assert torch.all(torch.eq(states_policy, states_policy_expected_torch))
265+
266+
267+
@pytest.mark.parametrize(
268+
"lattice_system, expected_output",
269+
[
270+
(CUBIC, "(1.0, 1.0, 1.0), (90.0, 90.0, 90.0)"),
271+
(HEXAGONAL, "(1.0, 1.0, 1.0), (90.0, 90.0, 120.0)"),
272+
(MONOCLINIC, "(1.0, 1.0, 1.0), (90.0, 30.0, 90.0)"),
273+
(ORTHORHOMBIC, "(1.0, 1.0, 1.0), (90.0, 90.0, 90.0)"),
274+
(RHOMBOHEDRAL, "(1.0, 1.0, 1.0), (30.0, 30.0, 30.0)"),
275+
(TETRAGONAL, "(1.0, 1.0, 1.0), (90.0, 90.0, 90.0)"),
276+
(TRICLINIC, "(1.0, 1.0, 1.0), (30.0, 30.0, 30.0)"),
277+
],
278+
)
279+
@pytest.mark.skip(reason="skip until it gets updated")
280+
def test__state2readable__gives_expected_results_for_initial_states(
281+
env, lattice_system, expected_output
282+
):
283+
assert env.state2readable() == expected_output
284+
285+
286+
@pytest.mark.parametrize(
287+
"lattice_system, readable",
288+
[
289+
(CUBIC, "(1.0, 1.0, 1.0), (90.0, 90.0, 90.0)"),
290+
(HEXAGONAL, "(1.0, 1.0, 1.0), (90.0, 90.0, 120.0)"),
291+
(MONOCLINIC, "(1.0, 1.0, 1.0), (90.0, 30.0, 90.0)"),
292+
(ORTHORHOMBIC, "(1.0, 1.0, 1.0), (90.0, 90.0, 90.0)"),
293+
(RHOMBOHEDRAL, "(1.0, 1.0, 1.0), (30.0, 30.0, 30.0)"),
294+
(TETRAGONAL, "(1.0, 1.0, 1.0), (90.0, 90.0, 90.0)"),
295+
(TRICLINIC, "(1.0, 1.0, 1.0), (30.0, 30.0, 30.0)"),
296+
],
297+
)
298+
@pytest.mark.skip(reason="skip until it gets updated")
299+
def test__readable2state__gives_expected_results_for_initial_states(
300+
env, lattice_system, readable
301+
):
302+
assert env.readable2state(readable) == env.state
303+
304+
305+
@pytest.mark.parametrize(
306+
"lattice_system",
307+
[CUBIC, HEXAGONAL, MONOCLINIC, ORTHORHOMBIC, RHOMBOHEDRAL, TETRAGONAL, TRICLINIC],
308+
)
309+
def test__continuous_env_common(env, lattice_system):
310+
return common.test__continuous_env_common(env)

0 commit comments

Comments
 (0)