diff --git a/ml4co_kit/generator/cvrp_data.py b/ml4co_kit/generator/cvrp_data.py index e581514..40c65e4 100644 --- a/ml4co_kit/generator/cvrp_data.py +++ b/ml4co_kit/generator/cvrp_data.py @@ -18,7 +18,7 @@ def __init__( num_threads: int = 1, nodes_num: int = 50, data_type: str = "uniform", - solver: Union[str, CVRPSolver] = "pyvrp", + solver: Union[str, CVRPSolver] = "PyVRP", train_samples_num: int = 128000, val_samples_num: int = 1280, test_samples_num: int = 1280, @@ -115,9 +115,9 @@ def check_solver(self): if type(self.solver) == str: self.solver_type = self.solver supported_solver_dict = { - "pyvrp": CVRPPyVRPSolver, - "lkh": CVRPLKHSolver, - "hgs": CVRPHGSSolver + "PyVRP": CVRPPyVRPSolver, + "LKH": CVRPLKHSolver, + "HGS": CVRPHGSSolver } supported_solver_type = supported_solver_dict.keys() if self.solver_type not in supported_solver_type: @@ -132,9 +132,9 @@ def check_solver(self): self.solver_type = self.solver.solver_type # check solver check_solver_dict = { - "lkh": self.check_lkh, - "pyvrp": self.check_free, - "hgs": self.check_free + "PyVRP": self.check_free, + "LKH": self.check_lkh, + "HGS": self.check_free } check_func = check_solver_dict[self.solver_type] check_func() @@ -194,25 +194,38 @@ def get_filename(self): os.makedirs(self.save_path) def generate(self): - with open(self.file_save_path, "w") as f: - start_time = time.time() - for _ in tqdm( - range(self.samples_num // self.num_threads), - desc=f"Solving CVRP Using {self.solver_type}", - ): - batch_depots_coord, batch_nodes_coord= self.generate_func() - batch_demands = self.generate_demands() - batch_capacities = self.generate_capacities() + start_time = time.time() + for _ in tqdm( + range(self.samples_num // self.num_threads), + desc=f"Solving CVRP Using {self.solver_type}", + ): + # call generate_func to generate the points + batch_depots_coord, batch_nodes_coord = self.generate_func() + batch_demands = self.generate_demands() + batch_capacities = self.generate_capacities() + + # solve + if self.num_threads == 1: + tours = self.solver.solve( + depots=batch_depots_coord[0], + points=batch_nodes_coord[0], + demands=batch_demands[0], + capacities=batch_capacities[0] + ) + tours = [tours] + else: with Pool(self.num_threads) as p1: tours = p1.starmap( self.solver.solve, [(batch_depots_coord[idx], - batch_nodes_coord[idx], - batch_demands[idx], - batch_capacities[idx]) - for idx in range(self.num_threads)], + batch_nodes_coord[idx], + batch_demands[idx], + batch_capacities[idx]) + for idx in range(self.num_threads)], ) - # write to txt + + # write to txt + with open(self.file_save_path, "w") as f: for idx, tour in enumerate(tours): depot = batch_depots_coord[idx] points = batch_nodes_coord[idx] @@ -231,13 +244,15 @@ def generate(self): f.write(str(" output ")) f.write(str(" ").join(str(node_idx) for node_idx in tour[0])) f.write("\n") - end_time = time.time() - start_time f.close() - print( - f"Completed generation of {self.samples_num} samples of CVRP{self.nodes_num}." - ) - print(f"Total time: {end_time/60:.1f}m") - print(f"Average time: {end_time/self.samples_num:.1f}s") + + # info + end_time = time.time() - start_time + print( + f"Completed generation of {self.samples_num} samples of CVRP{self.nodes_num}." + ) + print(f"Total time: {end_time/60:.1f}m") + print(f"Average time: {end_time/self.samples_num:.1f}s") self.devide_file() def devide_file(self): diff --git a/ml4co_kit/generator/mis_data.py b/ml4co_kit/generator/mis_data.py index f9b65c6..2ceb378 100644 --- a/ml4co_kit/generator/mis_data.py +++ b/ml4co_kit/generator/mis_data.py @@ -135,7 +135,7 @@ def check_solver(self): # check solver if type(self.solver) == str: self.solver_type = self.solver - supported_solver_dict = {"kamis": KaMISSolver, "gurobi": MISGurobiSolver} + supported_solver_dict = {"KaMIS": KaMISSolver, "Gurobi": MISGurobiSolver} supported_solver_type = supported_solver_dict.keys() if self.solver not in supported_solver_type: message = ( diff --git a/ml4co_kit/generator/tsp_data.py b/ml4co_kit/generator/tsp_data.py index fe25f99..4246ccc 100644 --- a/ml4co_kit/generator/tsp_data.py +++ b/ml4co_kit/generator/tsp_data.py @@ -25,7 +25,7 @@ def __init__( num_threads: int = 1, nodes_num: int = 50, data_type: str = "uniform", - solver: Union[str, TSPSolver] = "lkh", + solver: Union[str, TSPSolver] = "LKH", train_samples_num: int = 128000, val_samples_num: int = 1280, test_samples_num: int = 1280, @@ -128,11 +128,11 @@ def check_solver(self): if type(self.solver) == str: self.solver_type = self.solver supported_solver_dict = { - "lkh": TSPLKHSolver, - "concorde": TSPConcordeSolver, - "concorde-large": TSPConcordeLargeSolver, - "ga-eax": TSPGAEAXSolver, - "ga-eax-large": TSPGAEAXLargeSolver + "LKH": TSPLKHSolver, + "Concorde": TSPConcordeSolver, + "Concorde-Large": TSPConcordeLargeSolver, + "GA-EAX": TSPGAEAXSolver, + "GA-EAX-Large": TSPGAEAXLargeSolver } supported_solver_type = supported_solver_dict.keys() if self.solver_type not in supported_solver_type: @@ -147,11 +147,11 @@ def check_solver(self): self.solver_type = self.solver.solver_type # check solver check_solver_dict = { - "lkh": self.check_lkh, - "concorde": self.check_concorde, - "concorde-large": self.check_concorde, - "ga-eax": self.check_free, - "ga-eax-large": self.check_free + "LKH": self.check_lkh, + "Concorde": self.check_concorde, + "Concorde-Large": self.check_concorde, + "GA-EAX": self.check_free, + "GA-EAX-Large": self.check_free } check_func = check_solver_dict[self.solver_type] check_func() @@ -241,7 +241,9 @@ def generate(self): desc=f"Solving TSP Using {self.solver_type}", ): # call generate_func to generate the points - batch_nodes_coord = self.generate_func() + batch_nodes_coord = self.generate_func() + + # solve if self.num_threads == 1: tours = [self.solver.solve(batch_nodes_coord[0])] else: @@ -250,11 +252,11 @@ def generate(self): self.solver.solve, [batch_nodes_coord[idx] for idx in range(self.num_threads)], ) - - # deal with regret - if self.regret: p1.close() # Close the pool to indicate that no more tasks will be submitted p1.join() # Wait for all processes in the pool to complete + + # deal with regret + if self.regret: if self.num_threads == 1: self.generate_regret(tours[0], batch_nodes_coord[0], cnt) else: diff --git a/tests/test_generator.py b/tests/test_generator.py index 512cf4b..a7e6915 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -28,7 +28,7 @@ def _test_tsp_lkh_generator( num_threads=num_threads, nodes_num=nodes_num, data_type=data_type, - solver="lkh", + solver="LKH", train_samples_num=4, val_samples_num=4, test_samples_num=4, @@ -59,7 +59,7 @@ def _test_tsp_concorde_generator( num_threads=num_threads, nodes_num=nodes_num, data_type=data_type, - solver="concorde", + solver="Concorde", train_samples_num=4, val_samples_num=4, test_samples_num=4, @@ -90,7 +90,7 @@ def _test_tsp_concorde_large_generator( num_threads=num_threads, nodes_num=nodes_num, data_type=data_type, - solver="concorde-large", + solver="Concorde-Large", train_samples_num=1, val_samples_num=0, test_samples_num=0, @@ -120,7 +120,7 @@ def _test_tsp_ga_eax_generator( num_threads=num_threads, nodes_num=nodes_num, data_type=data_type, - solver="ga-eax", + solver="GA-EAX", train_samples_num=4, val_samples_num=4, test_samples_num=4, @@ -148,7 +148,7 @@ def _test_tsp_ga_eax_large_generator( num_threads=num_threads, nodes_num=nodes_num, data_type=data_type, - solver="ga-eax-large", + solver="GA-EAX-Large", train_samples_num=1, val_samples_num=0, test_samples_num=0, @@ -254,7 +254,7 @@ def _test_mis_gurobi( nodes_num_min=nodes_num_min, nodes_num_max=nodes_num_max, data_type=data_type, - solver="gurobi", + solver="Gurobi", train_samples_num=2, val_samples_num=2, test_samples_num=2, @@ -336,7 +336,7 @@ def _test_cvrp_lkh_generator( num_threads=num_threads, nodes_num=nodes_num, data_type=data_type, - solver="lkh", + solver="LKH", train_samples_num=4, val_samples_num=4, test_samples_num=4, @@ -365,7 +365,7 @@ def _test_cvrp_hgs_generator( num_threads=num_threads, nodes_num=nodes_num, data_type=data_type, - solver="hgs", + solver="HGS", train_samples_num=4, val_samples_num=4, test_samples_num=4,