Skip to content

Commit

Permalink
Fix bugs for PyVRP and update GAX solver
Browse files Browse the repository at this point in the history
  • Loading branch information
heatingma committed Nov 15, 2024
1 parent 5fe53e7 commit 9637cae
Show file tree
Hide file tree
Showing 25 changed files with 119 additions and 86 deletions.
9 changes: 5 additions & 4 deletions ml4co_kit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,14 @@
# Utils Function #
#######################################################
from .utils import download, compress_folder, extract_archive, _get_md5
from .utils import iterative_execution_for_file, iterative_execution
from .utils import iterative_execution_for_file, iterative_execution, Timer
from .utils import np_dense_to_sparse, np_sparse_to_dense, GraphData, tsplib95
from .utils import MISGraphData, MVCGraphData, MClGraphData, MCutGraphData
from .utils import sat_to_mis_graph, cnf_folder_to_gpickle_folder, cnf_to_gpickle

#######################################################
# Extension Function #
# Extension Function (matplotlib) #
#######################################################
# expand - matplotlib
found_matplotlib = importlib.util.find_spec("matplotlib")
if found_matplotlib is not None:
from .draw.cvrp import draw_cvrp_problem, draw_cvrp_solution
Expand All @@ -75,7 +74,9 @@
from .draw.mvc import draw_mvc_problem, draw_mvc_solution
from .draw.tsp import draw_tsp_problem, draw_tsp_solution

# expand - pytorch_lightning
#######################################################
# Extension Function (pytorch_lightning) #
#######################################################
found_pytorch_lightning = importlib.util.find_spec("pytorch_lightning")
if found_pytorch_lightning is not None:
from .learning.env import BaseEnv
Expand Down
25 changes: 18 additions & 7 deletions ml4co_kit/generator/mcut_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import networkx as nx
from tqdm import tqdm
from typing import Union
from typing import Union, List
from ml4co_kit.utils.graph.mcut import MCutGraphData
from ml4co_kit.utils.type_utils import SOLVER_TYPE
from ml4co_kit.solver import MCutSolver, MCutGurobiSolver
Expand All @@ -15,6 +15,7 @@
class MCutDataGenerator:
def __init__(
self,
only_instance_for_us: bool = False,
num_threads: int = 1,
nodes_num_min: int = 700,
nodes_num_max: int = 800,
Expand Down Expand Up @@ -92,13 +93,18 @@ def __init__(
self.ws_prob = ws_prob
self.ws_ring_neighbors = ws_ring_neighbors

# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()
# only instance for us
self.only_instance_for_us = only_instance_for_us
self.check_data_type()
self.check_solver()
self.check_save_path()
self.get_filename()

# generate and solve
if only_instance_for_us == False:
# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()
self.check_solver()
self.check_save_path()
self.get_filename()

def check_num_threads(self):
self.samples_num = 0
Expand Down Expand Up @@ -195,6 +201,11 @@ def check_free(self):
def random_weight(self, n, mu=1, sigma=0.1):
return np.around(np.random.normal(mu, sigma, n)).astype(int).clip(min=0)

def generate_only_instance_for_us(self, samples: int) -> List[MCutGraphData]:
nx_graphs = [self.generate_func() for _ in range(samples)]
self.solver.from_nx_graph(nx_graphs=nx_graphs)
return self.solver.graph_data

def generate(self):
start_time = time.time()
for _ in tqdm(
Expand Down
19 changes: 9 additions & 10 deletions ml4co_kit/solver/atsp/lkh.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import time
import uuid
import pathlib
import numpy as np
Expand All @@ -8,7 +7,7 @@
from subprocess import check_call
from ml4co_kit.solver.atsp.base import ATSPSolver
from ml4co_kit.utils.type_utils import SOLVER_TYPE
from ml4co_kit.utils.time_utils import iterative_execution
from ml4co_kit.utils.time_utils import iterative_execution, Timer


class ATSPLKHSolver(ATSPSolver):
Expand Down Expand Up @@ -100,7 +99,8 @@ def solve(
# prepare
self.from_data(dists=dists, normalize=normalize)
self.tmp_solver = ATSPSolver(scale=self.scale)
start_time = time.time()
timer = Timer(apply=show_time)
timer.start()

# solve
tours = list()
Expand All @@ -126,14 +126,13 @@ def solve(
tours.append(tour)

# format
tours = np.array(tours)
if tours.ndim == 2 and tours.shape[0] == 1:
tours = tours[0]
self.from_data(tours=tours, ref=False)
end_time = time.time()
if show_time:
print(f"Use Time: {end_time - start_time}")
return tours

# show time
timer.end()
timer.show_time()

return self.tours

def __str__(self) -> str:
return "ATSPLKHSolver"
16 changes: 8 additions & 8 deletions ml4co_kit/solver/cvrp/hgs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
import uuid
import time
import numpy as np
from typing import Union
from multiprocessing import Pool
from ml4co_kit.solver.cvrp.base import CVRPSolver
from ml4co_kit.utils.type_utils import SOLVER_TYPE
from ml4co_kit.utils.time_utils import iterative_execution
from ml4co_kit.utils.time_utils import iterative_execution, Timer
from ml4co_kit.solver.cvrp.c_hgs import cvrp_hgs_solver, HGS_TMP_PATH


Expand Down Expand Up @@ -87,9 +86,8 @@ def solve(
depots=depots, points=points, demands=demands,
capacities=capacities, norm=norm, normalize=normalize
)

# start time
start_time = time.time()
timer = Timer(apply=show_time)
timer.start()

# solve
tours = list()
Expand Down Expand Up @@ -127,9 +125,11 @@ def solve(

# format
self.from_data(tours=tours, ref=False)
end_time = time.time()
if show_time:
print(f"Use Time: {end_time - start_time}")

# show time
timer.end()
timer.show_time()

return self.tours

def __str__(self) -> str:
Expand Down
16 changes: 8 additions & 8 deletions ml4co_kit/solver/cvrp/lkh.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import time
import uuid
import pathlib
import numpy as np
Expand All @@ -9,7 +8,7 @@
from ml4co_kit.utils import tsplib95
from ml4co_kit.solver.cvrp.base import CVRPSolver
from ml4co_kit.utils.type_utils import SOLVER_TYPE
from ml4co_kit.utils.time_utils import iterative_execution
from ml4co_kit.utils.time_utils import iterative_execution, Timer


class CVRPLKHSolver(CVRPSolver):
Expand Down Expand Up @@ -145,9 +144,8 @@ def solve(
capacities=capacities, norm=norm, normalize=normalize
)
self.tmp_solver = CVRPSolver()

# start time
start_time = time.time()
timer = Timer(apply=show_time)
timer.start()

# solve
tours = list()
Expand Down Expand Up @@ -187,9 +185,11 @@ def solve(

# format
self.from_data(tours=tours, ref=False)
end_time = time.time()
if show_time:
print(f"Use Time: {end_time - start_time}")

# show time
timer.end()
timer.show_time()

return self.tours

def __str__(self) -> str:
Expand Down
21 changes: 11 additions & 10 deletions ml4co_kit/solver/cvrp/pyvrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pyvrp.stop import MaxRuntime
from ml4co_kit.solver.cvrp.base import CVRPSolver
from ml4co_kit.utils.type_utils import SOLVER_TYPE
from ml4co_kit.utils.time_utils import iterative_execution
from ml4co_kit.utils.time_utils import iterative_execution, Timer


if sys.version_info.major == 3 and sys.version_info.minor == 8:
Expand Down Expand Up @@ -54,9 +54,9 @@ def _solve(
cvrp_model.add_vehicle_type(capacity=capacity, num_available=max_num_available)
clients = [
cvrp_model.add_client(
self.round_func(nodes_coord[idx][0]),
self.round_func(nodes_coord[idx][1]),
self.round_func(demands[idx])
int(self.round_func(nodes_coord[idx][0])),
int(self.round_func(nodes_coord[idx][1])),
int(self.round_func(demands[idx]))
) for idx in range(0, len(nodes_coord))
]
locations = [depot] + clients
Expand Down Expand Up @@ -91,9 +91,8 @@ def solve(
capacities=capacities, norm=norm, normalize=normalize
)
self.round_func = self.get_round_func(round_func)

# start time
start_time = time.time()
timer = Timer(apply=show_time)
timer.start()

# solve
tours = list()
Expand Down Expand Up @@ -133,9 +132,11 @@ def solve(

# format
self.from_data(tours=tours)
end_time = time.time()
if show_time:
print(f"Use Time: {end_time - start_time}")

# show time
timer.end()
timer.show_time()

return self.tours

def __str__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion ml4co_kit/solver/mcl/gurobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def solve(
graph_data: List[MClGraphData] = None,
num_threads: int = 1,
show_time: bool = False
) -> np.ndarray:
) -> List[MClGraphData]:
# preparation
if graph_data is not None:
self.graph_data = graph_data
Expand Down
2 changes: 1 addition & 1 deletion ml4co_kit/solver/mcut/gurobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def solve(
graph_data: List[MCutGraphData] = None,
num_threads: int = 1,
show_time: bool = False
) -> np.ndarray:
) -> List[MCutGraphData]:
# preparation
if graph_data is not None:
self.graph_data = graph_data
Expand Down
2 changes: 1 addition & 1 deletion ml4co_kit/solver/mis/gurobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def solve(
graph_data: List[MISGraphData] = None,
num_threads: int = 1,
show_time: bool = False
) -> np.ndarray:
) -> List[MISGraphData]:
# preparation
if graph_data is not None:
self.graph_data = graph_data
Expand Down
8 changes: 5 additions & 3 deletions ml4co_kit/solver/mis/kamis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import networkx as nx
from tqdm import tqdm
from pathlib import Path
from typing import Union
from ml4co_kit.utils.type_utils import SOLVER_TYPE
from typing import Union, List
from ml4co_kit.solver.mis.base import MISSolver
from ml4co_kit.utils.graph.mis import MISGraphData
from ml4co_kit.utils.type_utils import SOLVER_TYPE


class KaMISSolver(MISSolver):
Expand Down Expand Up @@ -82,7 +83,7 @@ def prepare_instance(

def solve(
self, src: Union[str, pathlib.Path], out: Union[str, pathlib.Path],
):
) -> List[MISGraphData]:
message = (
"Please check KaMIS compilation. "
"you can try ``self.recompile_kamis()``. "
Expand All @@ -101,6 +102,7 @@ def solve(
self.from_gpickle_result_folder(
gpickle_folder_path=src, result_folder_path=out, ref=False, cover=True
)
return self.graph_data

def _solve(self, src: Union[str, pathlib.Path], out: Union[str, pathlib.Path]):
src = Path(src)
Expand Down
2 changes: 1 addition & 1 deletion ml4co_kit/solver/mvc/gurobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def solve(
graph_data: List[MVCGraphData] = None,
num_threads: int = 1,
show_time: bool = False
) -> np.ndarray:
) -> List[MVCGraphData]:
# preparation
if graph_data is not None:
self.graph_data = graph_data
Expand Down
4 changes: 2 additions & 2 deletions ml4co_kit/solver/tsp/c_ga_eax_large/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@

def tsp_ga_eax_large_solve(
max_trials: int, sol_name: str, population_num: int,
offspring_num: int, tsp_name: str,
offspring_num: int, tsp_name: str, show_info: bool = False
):
tsp_path = os.path.join("tmp", tsp_name)
sol_path = os.path.join("tmp", sol_name)
ori_dir = os.getcwd()
os.chdir(GA_EAX_LARGE_BASE_PATH)
command = f"./ga_eax_large_solver {max_trials} {sol_path} {population_num} {offspring_num} {tsp_path}"
command = f"./ga_eax_large_solver {max_trials} {sol_path} {population_num} {offspring_num} {tsp_path} {show_info}"
os.system(command)
os.chdir(ori_dir)

Expand Down
5 changes: 3 additions & 2 deletions ml4co_kit/solver/tsp/c_ga_eax_large/env.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ void TEnvironment::DoIt()
while( 1 )
{
this->SetAverageBest();
printf( "%d: %d %lf\n", fCurNumOfGen, fBestValue, fAverageValue );

if (showInfo){
printf( "%d: %d %lf\n", fCurNumOfGen, fBestValue, fAverageValue );
}
if( this->TerminationCondition() ) break;

this->SelectForMating();
Expand Down
2 changes: 2 additions & 0 deletions ml4co_kit/solver/tsp/c_ga_eax_large/env.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class TEnvironment {
int fMaxStagBest; /* If fStagBest = fMaxStagBest, proceed to the next stage */
int fCurNumOfGen1; /* Number of generations at which Stage I is terminated */

int showInfo;

clock_t fTimeStart, fTimeInit, fTimeEnd; /* Use them to measure the execution time */
};

Expand Down
8 changes: 5 additions & 3 deletions ml4co_kit/solver/tsp/c_ga_eax_large/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,18 @@ int main( int argc, char* argv[] )
gEnv->fNumOfKids = d;
gEnv->fFileNameTSP = argv[5];
gEnv->fFileNameInitPop = NULL;
if( argc == 7 )
gEnv->fFileNameInitPop = argv[6];
sscanf( argv[6], "%d", &d );
gEnv->showInfo = d;
if( argc == 8 )
gEnv->fFileNameInitPop = argv[7];

gEnv->Define();

for( int n = 0; n < maxNumOfTrial; ++n )
{
gEnv->DoIt();

gEnv->PrintOn( n, dstFile );
if (gEnv->showInfo){gEnv->PrintOn( n, dstFile );}
gEnv->WriteBest( dstFile );
// gEnv->WritePop( n, dstFile );
}
Expand Down
Loading

0 comments on commit 9637cae

Please sign in to comment.