Skip to content

Commit 119d04a

Browse files
committed
fix bugs
1 parent 566bf5b commit 119d04a

File tree

5 files changed

+31
-30
lines changed

5 files changed

+31
-30
lines changed

ml4co_kit/generator/atsp_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,5 +359,5 @@ def generate_uniform(self) -> Union[np.ndarray, np.ndarray]:
359359
dist = (dist[:, None, :] + dist[None, :, :].transpose(0, 2, 1)).min(axis=2)
360360
if (dist == old_dist).all():
361361
break
362-
dists.append(dist / scaler)
362+
dists.append(dist / scaler)
363363
return np.array(dists), None

ml4co_kit/generator/cvrp_data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,10 @@ def generate(self):
211211
depots=batch_depots_coord,
212212
points=batch_nodes_coord,
213213
demands=batch_demands,
214-
capacities=batch_capacities,
214+
capacities=batch_capacities.reshape(-1),
215215
num_threads=self.num_threads
216216
)
217-
217+
218218
# write to txt
219219
with open(self.file_save_path, "a+") as f:
220220
for idx, tour in enumerate(tours):
@@ -233,7 +233,7 @@ def generate(self):
233233
f.write(" demands " + str(" ").join(str(demand) for demand in demands))
234234
f.write(" capacity " + str(capicity))
235235
f.write(str(" output "))
236-
f.write(str(" ").join(str(node_idx) for node_idx in tour[0]))
236+
f.write(str(" ").join(str(node_idx) for node_idx in tour))
237237
f.write("\n")
238238
f.close()
239239

ml4co_kit/generator/mis_data.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ def check_solver(self):
159159
if isinstance(self.solver, SOLVER_TYPE):
160160
self.solver_type = self.solver
161161
supported_solver_dict = {
162-
SOLVER_TYPE.KAMIS: KaMISSolver,
163-
SOLVER_TYPE.GUROBI: MISGurobiSolver
162+
SOLVER_TYPE.GUROBI: MISGurobiSolver,
163+
SOLVER_TYPE.KAMIS: KaMISSolver
164164
}
165165
supported_solver_type = supported_solver_dict.keys()
166166
if self.solver not in supported_solver_type:
@@ -177,6 +177,7 @@ def check_solver(self):
177177
# check solver
178178
check_solver_dict = {
179179
SOLVER_TYPE.GUROBI: self.check_free,
180+
SOLVER_TYPE.KAMIS: self.check_free
180181
}
181182
check_func = check_solver_dict[self.solver_type]
182183
check_func()

ml4co_kit/generator/tsp_data.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,11 @@ def check_solver(self):
132132
if isinstance(self.solver, SOLVER_TYPE):
133133
self.solver_type = self.solver
134134
supported_solver_dict = {
135-
"LKH": TSPLKHSolver,
136-
"Concorde": TSPConcordeSolver,
137-
"Concorde-Large": TSPConcordeLargeSolver,
138-
"GA-EAX": TSPGAEAXSolver,
139-
"GA-EAX-Large": TSPGAEAXLargeSolver
135+
SOLVER_TYPE.CONCORDE: TSPConcordeSolver,
136+
SOLVER_TYPE.LKH: TSPLKHSolver,
137+
SOLVER_TYPE.CONCORDE_LARGE: TSPConcordeLargeSolver,
138+
SOLVER_TYPE.GA_EAX: TSPGAEAXSolver,
139+
SOLVER_TYPE.GA_EAX_LARGE: TSPGAEAXLargeSolver
140140
}
141141
supported_solver_type = supported_solver_dict.keys()
142142
if self.solver_type not in supported_solver_type:

tests/test_generator.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
##############################################
1515

1616
def _test_atsp_lkh_generator(
17-
num_threads: int, nodes_num: int, data_type: str,
18-
sat_vars_num: int = None, sat_clauses_nums: int = None
17+
num_threads: int, nodes_num: int, data_type: str, sat_vars_num: int = None,
18+
sat_clauses_nums: int = None, re_download: bool = False
1919
):
2020
"""
2121
Test ATSPDataGenerator using ATSPLKHSolver
@@ -26,7 +26,7 @@ def _test_atsp_lkh_generator(
2626
os.makedirs(save_path)
2727

2828
# create TSPDataGenerator using lkh solver
29-
tsp_data_lkh = ATSPDataGenerator(
29+
atsp_data_lkh = ATSPDataGenerator(
3030
num_threads=num_threads,
3131
nodes_num=nodes_num,
3232
data_type=data_type,
@@ -38,9 +38,12 @@ def _test_atsp_lkh_generator(
3838
sat_vars_nums=sat_vars_num,
3939
sat_clauses_nums=sat_clauses_nums,
4040
)
41-
41+
42+
if re_download:
43+
atsp_data_lkh.download_lkh()
44+
4245
# generate data
43-
tsp_data_lkh.generate()
46+
atsp_data_lkh.generate()
4447

4548
# remove the save path
4649
shutil.rmtree(save_path)
@@ -50,6 +53,10 @@ def test_atsp():
5053
"""
5154
Test ATSPDataGenerator
5255
"""
56+
# uniform
57+
_test_atsp_lkh_generator(
58+
num_threads=4, nodes_num=50, data_type="uniform", re_download=True
59+
)
5360
# sat
5461
_test_atsp_lkh_generator(
5562
num_threads=4, nodes_num=55, data_type="sat", sat_clauses_nums=5, sat_vars_num=5
@@ -62,10 +69,6 @@ def test_atsp():
6269
_test_atsp_lkh_generator(
6370
num_threads=4, nodes_num=50, data_type="hcp"
6471
)
65-
# uniform
66-
_test_atsp_lkh_generator(
67-
num_threads=4, nodes_num=50, data_type="uniform"
68-
)
6972

7073

7174
##############################################
@@ -443,8 +446,7 @@ def test_mvc():
443446
##############################################
444447

445448
def _test_tsp_lkh_generator(
446-
num_threads: int, nodes_num: int, data_type: str,
447-
regret: bool, re_download: bool=False
449+
num_threads: int, nodes_num: int, data_type: str, regret: bool
448450
):
449451
"""
450452
Test TSPDataGenerator using LKH Solver
@@ -465,8 +467,7 @@ def _test_tsp_lkh_generator(
465467
save_path=save_path,
466468
regret=regret,
467469
)
468-
if re_download:
469-
tsp_data_lkh.download_lkh()
470+
470471
# generate data
471472
tsp_data_lkh.generate()
472473
# remove the save path
@@ -564,10 +565,9 @@ def test_tsp():
564565
"""
565566
Test TSPDataGenerator
566567
"""
567-
# re-download lkh
568+
# threads
568569
_test_tsp_lkh_generator(
569-
num_threads=4, nodes_num=50, data_type="uniform",
570-
regret=False, re_download=True
570+
num_threads=4, nodes_num=50, data_type="uniform", regret=False
571571
)
572572
# regret & threads
573573
_test_tsp_lkh_generator(
@@ -603,11 +603,11 @@ def test_tsp():
603603
##############################################
604604

605605
if __name__ == "__main__":
606-
test_tsp()
606+
test_atsp()
607+
test_cvrp()
607608
test_mc()
608609
test_mcl()
609610
test_mis()
610611
test_mvc()
611-
test_cvrp()
612-
test_atsp()
612+
test_tsp()
613613
shutil.rmtree("tmp")

0 commit comments

Comments
 (0)