Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
heatingma committed Jun 4, 2024
1 parent 4bd0b65 commit 6a1aad8
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 51 deletions.
69 changes: 42 additions & 27 deletions ml4co_kit/generator/cvrp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(
num_threads: int = 1,
nodes_num: int = 50,
data_type: str = "uniform",
solver: Union[str, CVRPSolver] = "pyvrp",
solver: Union[str, CVRPSolver] = "PyVRP",
train_samples_num: int = 128000,
val_samples_num: int = 1280,
test_samples_num: int = 1280,
Expand Down Expand Up @@ -115,9 +115,9 @@ def check_solver(self):
if type(self.solver) == str:
self.solver_type = self.solver
supported_solver_dict = {
"pyvrp": CVRPPyVRPSolver,
"lkh": CVRPLKHSolver,
"hgs": CVRPHGSSolver
"PyVRP": CVRPPyVRPSolver,
"LKH": CVRPLKHSolver,
"HGS": CVRPHGSSolver
}
supported_solver_type = supported_solver_dict.keys()
if self.solver_type not in supported_solver_type:
Expand All @@ -132,9 +132,9 @@ def check_solver(self):
self.solver_type = self.solver.solver_type
# check solver
check_solver_dict = {
"lkh": self.check_lkh,
"pyvrp": self.check_free,
"hgs": self.check_free
"PyVRP": self.check_free,
"LKH": self.check_lkh,
"HGS": self.check_free
}
check_func = check_solver_dict[self.solver_type]
check_func()
Expand Down Expand Up @@ -194,25 +194,38 @@ def get_filename(self):
os.makedirs(self.save_path)

def generate(self):
with open(self.file_save_path, "w") as f:
start_time = time.time()
for _ in tqdm(
range(self.samples_num // self.num_threads),
desc=f"Solving CVRP Using {self.solver_type}",
):
batch_depots_coord, batch_nodes_coord= self.generate_func()
batch_demands = self.generate_demands()
batch_capacities = self.generate_capacities()
start_time = time.time()
for _ in tqdm(
range(self.samples_num // self.num_threads),
desc=f"Solving CVRP Using {self.solver_type}",
):
# call generate_func to generate the points
batch_depots_coord, batch_nodes_coord = self.generate_func()
batch_demands = self.generate_demands()
batch_capacities = self.generate_capacities()

# solve
if self.num_threads == 1:
tours = self.solver.solve(
depots=batch_depots_coord[0],
points=batch_nodes_coord[0],
demands=batch_demands[0],
capacities=batch_capacities[0]
)
tours = [tours]
else:
with Pool(self.num_threads) as p1:
tours = p1.starmap(
self.solver.solve,
[(batch_depots_coord[idx],
batch_nodes_coord[idx],
batch_demands[idx],
batch_capacities[idx])
for idx in range(self.num_threads)],
batch_nodes_coord[idx],
batch_demands[idx],
batch_capacities[idx])
for idx in range(self.num_threads)],
)
# write to txt

# write to txt
with open(self.file_save_path, "w") as f:
for idx, tour in enumerate(tours):
depot = batch_depots_coord[idx]
points = batch_nodes_coord[idx]
Expand All @@ -231,13 +244,15 @@ def generate(self):
f.write(str(" output "))
f.write(str(" ").join(str(node_idx) for node_idx in tour[0]))
f.write("\n")
end_time = time.time() - start_time
f.close()
print(
f"Completed generation of {self.samples_num} samples of CVRP{self.nodes_num}."
)
print(f"Total time: {end_time/60:.1f}m")
print(f"Average time: {end_time/self.samples_num:.1f}s")

# info
end_time = time.time() - start_time
print(
f"Completed generation of {self.samples_num} samples of CVRP{self.nodes_num}."
)
print(f"Total time: {end_time/60:.1f}m")
print(f"Average time: {end_time/self.samples_num:.1f}s")
self.devide_file()

def devide_file(self):
Expand Down
2 changes: 1 addition & 1 deletion ml4co_kit/generator/mis_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def check_solver(self):
# check solver
if type(self.solver) == str:
self.solver_type = self.solver
supported_solver_dict = {"kamis": KaMISSolver, "gurobi": MISGurobiSolver}
supported_solver_dict = {"KaMIS": KaMISSolver, "Gurobi": MISGurobiSolver}
supported_solver_type = supported_solver_dict.keys()
if self.solver not in supported_solver_type:
message = (
Expand Down
32 changes: 17 additions & 15 deletions ml4co_kit/generator/tsp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
num_threads: int = 1,
nodes_num: int = 50,
data_type: str = "uniform",
solver: Union[str, TSPSolver] = "lkh",
solver: Union[str, TSPSolver] = "LKH",
train_samples_num: int = 128000,
val_samples_num: int = 1280,
test_samples_num: int = 1280,
Expand Down Expand Up @@ -128,11 +128,11 @@ def check_solver(self):
if type(self.solver) == str:
self.solver_type = self.solver
supported_solver_dict = {
"lkh": TSPLKHSolver,
"concorde": TSPConcordeSolver,
"concorde-large": TSPConcordeLargeSolver,
"ga-eax": TSPGAEAXSolver,
"ga-eax-large": TSPGAEAXLargeSolver
"LKH": TSPLKHSolver,
"Concorde": TSPConcordeSolver,
"Concorde-Large": TSPConcordeLargeSolver,
"GA-EAX": TSPGAEAXSolver,
"GA-EAX-Large": TSPGAEAXLargeSolver
}
supported_solver_type = supported_solver_dict.keys()
if self.solver_type not in supported_solver_type:
Expand All @@ -147,11 +147,11 @@ def check_solver(self):
self.solver_type = self.solver.solver_type
# check solver
check_solver_dict = {
"lkh": self.check_lkh,
"concorde": self.check_concorde,
"concorde-large": self.check_concorde,
"ga-eax": self.check_free,
"ga-eax-large": self.check_free
"LKH": self.check_lkh,
"Concorde": self.check_concorde,
"Concorde-Large": self.check_concorde,
"GA-EAX": self.check_free,
"GA-EAX-Large": self.check_free
}
check_func = check_solver_dict[self.solver_type]
check_func()
Expand Down Expand Up @@ -241,7 +241,9 @@ def generate(self):
desc=f"Solving TSP Using {self.solver_type}",
):
# call generate_func to generate the points
batch_nodes_coord = self.generate_func()
batch_nodes_coord = self.generate_func()

# solve
if self.num_threads == 1:
tours = [self.solver.solve(batch_nodes_coord[0])]
else:
Expand All @@ -250,11 +252,11 @@ def generate(self):
self.solver.solve,
[batch_nodes_coord[idx] for idx in range(self.num_threads)],
)

# deal with regret
if self.regret:
p1.close() # Close the pool to indicate that no more tasks will be submitted
p1.join() # Wait for all processes in the pool to complete

# deal with regret
if self.regret:
if self.num_threads == 1:
self.generate_regret(tours[0], batch_nodes_coord[0], cnt)
else:
Expand Down
16 changes: 8 additions & 8 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _test_tsp_lkh_generator(
num_threads=num_threads,
nodes_num=nodes_num,
data_type=data_type,
solver="lkh",
solver="LKH",
train_samples_num=4,
val_samples_num=4,
test_samples_num=4,
Expand Down Expand Up @@ -59,7 +59,7 @@ def _test_tsp_concorde_generator(
num_threads=num_threads,
nodes_num=nodes_num,
data_type=data_type,
solver="concorde",
solver="Concorde",
train_samples_num=4,
val_samples_num=4,
test_samples_num=4,
Expand Down Expand Up @@ -90,7 +90,7 @@ def _test_tsp_concorde_large_generator(
num_threads=num_threads,
nodes_num=nodes_num,
data_type=data_type,
solver="concorde-large",
solver="Concorde-Large",
train_samples_num=1,
val_samples_num=0,
test_samples_num=0,
Expand Down Expand Up @@ -120,7 +120,7 @@ def _test_tsp_ga_eax_generator(
num_threads=num_threads,
nodes_num=nodes_num,
data_type=data_type,
solver="ga-eax",
solver="GA-EAX",
train_samples_num=4,
val_samples_num=4,
test_samples_num=4,
Expand Down Expand Up @@ -148,7 +148,7 @@ def _test_tsp_ga_eax_large_generator(
num_threads=num_threads,
nodes_num=nodes_num,
data_type=data_type,
solver="ga-eax-large",
solver="GA-EAX-Large",
train_samples_num=1,
val_samples_num=0,
test_samples_num=0,
Expand Down Expand Up @@ -254,7 +254,7 @@ def _test_mis_gurobi(
nodes_num_min=nodes_num_min,
nodes_num_max=nodes_num_max,
data_type=data_type,
solver="gurobi",
solver="Gurobi",
train_samples_num=2,
val_samples_num=2,
test_samples_num=2,
Expand Down Expand Up @@ -336,7 +336,7 @@ def _test_cvrp_lkh_generator(
num_threads=num_threads,
nodes_num=nodes_num,
data_type=data_type,
solver="lkh",
solver="LKH",
train_samples_num=4,
val_samples_num=4,
test_samples_num=4,
Expand Down Expand Up @@ -365,7 +365,7 @@ def _test_cvrp_hgs_generator(
num_threads=num_threads,
nodes_num=nodes_num,
data_type=data_type,
solver="hgs",
solver="HGS",
train_samples_num=4,
val_samples_num=4,
test_samples_num=4,
Expand Down

0 comments on commit 6a1aad8

Please sign in to comment.