diff --git a/leap_net/proxy/proxyLeapNet.py b/leap_net/proxy/proxyLeapNet.py index 03b2b81..6f6de1c 100644 --- a/leap_net/proxy/proxyLeapNet.py +++ b/leap_net/proxy/proxyLeapNet.py @@ -1093,13 +1093,32 @@ def _given_list_topo_encode(self, obs): continue # so i have a different topology that the reference one - lookup = (sub_id, tuple([el if el >= 1 else 1 for el in this_sub_topo])) - if lookup in self.dict_topo: - res[self.dict_topo[lookup]] = 1. + #lookup = (sub_id, tuple([el if el >= 1 else 1 for el in this_sub_topo])) + topo_found=False + + if (np.all(conn)): + lookup = (sub_id, tuple([el if el >= 1 else 1 for el in this_sub_topo])) + if lookup in self.dict_topo: + res[self.dict_topo[lookup]] = 1. + topo_found = True else: + for topo_sub in self.dict_topo: + if topo_sub[0] == sub_id: + topo=np.array(topo_sub[1]) + if np.all(topo[conn]==this_sub_topo[conn]): + res[self.dict_topo[topo_sub]] = 1. + topo_found=True + conn_elements_bus_bar_2=conn[(topo_found==2)] + break + + #if lookup in self.dict_topo: + # res[self.dict_topo[lookup]] = 1. + #else: + if not topo_found: warnings.warn(f"Topology {lookup} is not found on the topo dictionary") return res + def _online_list_topo_encode(self, obs): """ This method behaves exaclyt like :func:`ProxyLeapNet._given_list_topo_encode` with one difference: you do not diff --git a/leap_net/test/test_LeapNetProxy.py b/leap_net/test/test_LeapNetProxy.py index 06007b9..6e5df34 100644 --- a/leap_net/test/test_LeapNetProxy.py +++ b/leap_net/test/test_LeapNetProxy.py @@ -215,7 +215,7 @@ def _aux_test_tau_from_list_topo(self, proxy=None): proxy = ProxyLeapNet(attr_tau=("line_status", "topo_vect",), topo_vect_to_tau="given_list", kwargs_tau=[(0, (2, 1, 1)), (0, (1, 2, 1)), (1, (2, 1, 1, 1, 1, 1)), - (12, (2, 1, 1, 2)), (13, (2, 1, 2)), (13, (1, 2, 2))] + (12, (2, 1, 1, 2)), (13, (2, 1, 2)), (13, (1, 2, 2)), (1, (2, 1, 2, 1, 2, 1))] ) proxy.init([self.obs]) @@ -296,6 +296,23 @@ def _aux_test_tau_from_list_topo(self, proxy=None): assert np.sum(res) == 1 assert res[2] == 1. + # test that if a line is disconnected, we are still able to match the topologies + env = self.env + obs = env.reset() + act = env.action_space({"set_bus": {"substations_id": [(1, (2, 1, -1, 1, 2, 1))]}})#(2, 1, 2, 1, 2, 1) + obs, reward, done, info = env.step(act) + res = proxy.topo_vect_handler(obs) + assert np.sum(res) == 1 + assert res[6] == 1. + + env = self.env + obs = env.reset() + act = env.action_space({"set_bus": {"substations_id": [(1, (2, -1, 2, 1, 2, 1))]}}) + obs, reward, done, info = env.step(act) + res = proxy.topo_vect_handler(obs) + assert np.sum(res) == 1 + assert res[6] == 1. + def test_tau_from_online_topo(self): self._aux_test_tau_from_online_topo()