From 67fadfdc05feeaa2b7eb4381e41f9c5bff889be5 Mon Sep 17 00:00:00 2001 From: heatingma Date: Sat, 16 Nov 2024 16:51:29 +0800 Subject: [PATCH] Fix bugs --- README.md | 2 +- ml4co_kit/solver/cvrp/base.py | 4 ++-- ml4co_kit/solver/tsp/base.py | 2 +- ml4co_kit/solver/tsp/c_ga_eax_large/__init__.py | 1 + tests/test_draw.py | 10 +++++----- tests/test_solver.py | 6 +++--- 6 files changed, 13 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index bab960c..1d5b1cc 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ tsp_data_lkh = TSPDataGenerator( num_threads=8, nodes_num=50, data_type="uniform", - solver="lkh", + solver="LKH", train_samples_num=16, val_samples_num=16, test_samples_num=16, diff --git a/ml4co_kit/solver/cvrp/base.py b/ml4co_kit/solver/cvrp/base.py index 1bc4d05..170efc9 100644 --- a/ml4co_kit/solver/cvrp/base.py +++ b/ml4co_kit/solver/cvrp/base.py @@ -264,7 +264,7 @@ def _read_data_from_vrp_file(self, vrp_file_path: str, round_func: str): points_list.append([client.x, client.y]) demands_list.append(client.demand if CP38 else client.delivery) points = np.array(points_list) - demands = np.array(demands_list) + demands = np.array(demands_list).reshape(-1) # capacity capacity = _vehicle_types.capacity @@ -650,7 +650,7 @@ def to_vrplib_folder( # demands and capacities need be int demands = demands.astype(np.int32) capacities = capacities.astype(np.int32) - + # .vrp files if vrp_save_dir is not None: # filename diff --git a/ml4co_kit/solver/tsp/base.py b/ml4co_kit/solver/tsp/base.py index 1a76c51..6ccc9da 100644 --- a/ml4co_kit/solver/tsp/base.py +++ b/ml4co_kit/solver/tsp/base.py @@ -404,7 +404,7 @@ def to_tsplib_folder( # write with open(save_path, "w") as f: - f.write(f"NAME : {name}") + f.write(f"NAME : {name}\n") f.write(f"COMMENT : Generated by ML4CO-Kit\n") f.write("TYPE : TSP\n") f.write(f"DIMENSION : {self.nodes_num}\n") diff --git a/ml4co_kit/solver/tsp/c_ga_eax_large/__init__.py b/ml4co_kit/solver/tsp/c_ga_eax_large/__init__.py index 6377ef6..c4b9e8d 100644 --- a/ml4co_kit/solver/tsp/c_ga_eax_large/__init__.py +++ b/ml4co_kit/solver/tsp/c_ga_eax_large/__init__.py @@ -23,6 +23,7 @@ def tsp_ga_eax_large_solve( max_trials: int, sol_name: str, population_num: int, offspring_num: int, tsp_name: str, show_info: bool = False ): + show_info = 1 if show_info else 0 tsp_path = os.path.join("tmp", tsp_name) sol_path = os.path.join("tmp", sol_name) ori_dir = os.getcwd() diff --git a/tests/test_draw.py b/tests/test_draw.py index ffaef0b..8e700e3 100644 --- a/tests/test_draw.py +++ b/tests/test_draw.py @@ -144,9 +144,9 @@ def test_draw_tsp(): ############################################## if __name__ == "__main__": - # test_draw_cvrp() - # test_draw_mcl() - # test_draw_mcut() - # test_draw_mis() + test_draw_cvrp() + test_draw_mcl() + test_draw_mcut() + test_draw_mis() test_draw_mvc() - # test_draw_tsp() + test_draw_tsp() diff --git a/tests/test_solver.py b/tests/test_solver.py index 164503e..9e1685f 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -551,10 +551,10 @@ def test_tsp(): ############################################## if __name__ == "__main__": - test_tsp() + test_atsp() + test_cvrp() test_mcl() test_mcut() test_mis() test_mvc() - test_cvrp() - test_atsp() + test_tsp()