From 261b8fb5e851aedba8fa85c05c7d3f5ea95c7a69 Mon Sep 17 00:00:00 2001
From: alexhernandezgarcia
Date: Tue, 28 Nov 2023 13:13:02 -0500
Subject: [PATCH] Simplify state conversions in CCrystal.

---
Review notes (this section is not part of the commit message):
* CCrystal.states2proxy() now delegates to subenv.states2proxy(). The
  original hunk called subenv.states2oracle(), which contradicts both the
  rename done by this patch and the updated test
  test__states2proxy__is_concatenation_of_subenv_states below, which builds
  its expected value from subenv.states2proxy().
* Line structure and indentation were restored from a whitespace-collapsed
  copy. Hunk line counts match the @@ headers and the diffstat, but the
  leading whitespace of some context lines is a best guess; if a context
  line fails to match, apply with `git apply --ignore-whitespace`.
* Because one added line changed, the post-image blob id on the ccrystal.py
  "index" line no longer corresponds to the resulting file.

 gflownet/envs/crystals/ccrystal.py   | 100 ++++++++-------------------
 tests/gflownet/envs/test_ccrystal.py |  18 ++---
 2 files changed, 38 insertions(+), 80 deletions(-)

diff --git a/gflownet/envs/crystals/ccrystal.py b/gflownet/envs/crystals/ccrystal.py
index 0336d8a43..c6eadb364 100644
--- a/gflownet/envs/crystals/ccrystal.py
+++ b/gflownet/envs/crystals/ccrystal.py
@@ -857,100 +857,58 @@ def get_logprobs(
             )
         return logprobs
 
-    def state2policy(self, state: Optional[List[int]] = None) -> Tensor:
-        """
-        Prepares one state in "GFlowNet format" for the policy. Simply
-        a concatenation of all crystal components.
-        """
-        state = self._get_state(state)
-        return self.statetorch2policy(
-            torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0)
-        )[0]
-
-    def statebatch2policy(
-        self, states: List[List]
+    def states2policy(
+        self, states: Union[List[List], TensorType["batch", "state_dim"]]
     ) -> TensorType["batch", "state_policy_dim"]:
         """
-        Prepares a batch of states in "GFlowNet format" for the policy. Simply
+        Prepares a batch of states in "environment format" for the policy model: simply
         a concatenation of all crystal components.
-        """
-        return self.statetorch2policy(
-            tfloat(states, device=self.device, float_type=self.float)
-        )
 
-    def statetorch2policy(
-        self, states: TensorType["batch", "state_dim"]
-    ) -> TensorType["batch", "state_policy_dim"]:
-        """
-        Prepares a tensor batch of states in "GFlowNet format" for the policy. Simply
-        a concatenation of all crystal components.
+        Args
+        ----
+        states : list or tensor
+            A batch of states in environment format, either as a list of states or as a
+            single tensor.
+
+        Returns
+        -------
+        A tensor containing all the states in the batch.
         """
+        states = tfloat(states, device=self.device, float_type=self.float)
         return torch.cat(
             [
-                subenv.statetorch2policy(self._get_states_of_subenv(states, stage))
+                subenv.states2policy(self._get_states_of_subenv(states, stage))
                 for stage, subenv in self.subenvs.items()
             ],
             dim=1,
         )
 
-    def state2oracle(self, state: Optional[List[int]] = None) -> Tensor:
-        """
-        Prepares one state in "GFlowNet format" for the oracle. Simply
-        a concatenation of all crystal components.
-        """
-        state = self._get_state(state)
-        return self.statetorch2oracle(
-            torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0)
-        )
-
-    def statebatch2oracle(
-        self, states: List[List]
+    def states2proxy(
+        self, states: Union[List[List], TensorType["batch", "state_dim"]]
     ) -> TensorType["batch", "state_oracle_dim"]:
         """
-        Prepares a batch of states in "GFlowNet format" for the oracle. Simply
-        a concatenation of all crystal components.
-        """
-        return self.statetorch2oracle(
-            tfloat(states, device=self.device, float_type=self.float)
-        )
+        Prepares a batch of states in "environment format" for a proxy: simply a
+        concatenation of all crystal components.
 
-    def statetorch2oracle(
-        self, states: TensorType["batch", "state_dim"]
-    ) -> TensorType["batch", "state_oracle_dim"]:
-        """
-        Prepares one state in "GFlowNet format" for the oracle. Simply
-        a concatenation of all crystal components.
+        Args
+        ----
+        states : list or tensor
+            A batch of states in environment format, either as a list of states or as a
+            single tensor.
+
+        Returns
+        -------
+        A tensor containing all the states in the batch.
         """
+        states = tfloat(states, device=self.device, float_type=self.float)
         return torch.cat(
             [
-                subenv.statetorch2oracle(self._get_states_of_subenv(states, stage))
+                subenv.states2proxy(self._get_states_of_subenv(states, stage))
                 for stage, subenv in self.subenvs.items()
             ],
             dim=1,
         )
 
-    def state2proxy(self, state: Optional[List[int]] = None) -> Tensor:
-        """
-        Returns state2oracle(state).
-        """
-        return self.state2oracle(state)
-
-    def statebatch2proxy(
-        self, states: List[List]
-    ) -> TensorType["batch", "state_oracle_dim"]:
-        """
-        Returns statebatch2oracle(states).
-        """
-        return self.statebatch2oracle(states)
-
-    def statetorch2proxy(
-        self, states: TensorType["batch", "state_dim"]
-    ) -> TensorType["batch", "state_oracle_dim"]:
-        """
-        Returns statetorch2oracle(states).
-        """
-        return self.statetorch2oracle(states)
-
     def set_state(self, state: List, done: Optional[bool] = False):
         super().set_state(state, done)
 
diff --git a/tests/gflownet/envs/test_ccrystal.py b/tests/gflownet/envs/test_ccrystal.py
index e88a348ae..4d0cdf00b 100644
--- a/tests/gflownet/envs/test_ccrystal.py
+++ b/tests/gflownet/envs/test_ccrystal.py
@@ -152,22 +152,22 @@ def test__pad_depad_action(env):
         ],
     ],
 )
-def test__statetorch2policy__is_concatenation_of_subenv_states(env, states):
+def test__states2policy__is_concatenation_of_subenv_states(env, states):
     # Get policy states from the batch of states converted into each subenv
     states_dict = {stage: [] for stage in env.subenvs}
     for state in states:
         for stage in env.subenvs:
             states_dict[stage].append(env._get_state_of_subenv(state, stage))
     states_policy_dict = {
-        stage: subenv.statebatch2policy(states_dict[stage])
+        stage: subenv.states2policy(states_dict[stage])
         for stage, subenv in env.subenvs.items()
     }
     states_policy_expected = torch.cat(
         [el for el in states_policy_dict.values()], dim=1
     )
-    # Get policy states from env.statetorch2policy
+    # Get policy states from env.states2policy
     states_torch = tfloat(states, float_type=env.float, device=env.device)
-    states_policy = env.statetorch2policy(states_torch)
+    states_policy = env.states2policy(states_torch)
     assert torch.all(torch.eq(states_policy, states_policy_expected))
 
 
@@ -191,20 +191,20 @@ def test__statetorch2policy__is_concatenation_of_subenv_states(env, states):
         ],
     ],
 )
-def test__statetorch2proxy__is_concatenation_of_subenv_states(env, states):
+def test__states2proxy__is_concatenation_of_subenv_states(env, states):
     # Get proxy states from the batch of states converted into each subenv
     states_dict = {stage: [] for stage in env.subenvs}
     for state in states:
         for stage in env.subenvs:
             states_dict[stage].append(env._get_state_of_subenv(state, stage))
     states_proxy_dict = {
-        stage: subenv.statebatch2proxy(states_dict[stage])
+        stage: subenv.states2proxy(states_dict[stage])
         for stage, subenv in env.subenvs.items()
     }
     states_proxy_expected = torch.cat([el for el in states_proxy_dict.values()], dim=1)
-    # Get proxy states from env.statetorch2proxy
+    # Get proxy states from env.states2proxy
     states_torch = tfloat(states, float_type=env.float, device=env.device)
-    states_proxy = env.statetorch2proxy(states_torch)
+    states_proxy = env.states2proxy(states_torch)
     assert torch.all(torch.eq(states_proxy, states_proxy_expected))
 
 
@@ -243,7 +243,7 @@ def test__state2readable__is_concatenation_of_subenv_states(env, states):
             f"SpaceGroup = {readables[1]}; "
             f"LatticeParameters = {readables[2]}"
         )
-    # Get policy states from env.statetorch2policy
+    # Get policy states from env.states2policy
     states_readable = [env.state2readable(state) for state in states]
     for readable, readable_expected in zip(states_readable, states_readable_expected):
         assert readable == readable_expected