Skip to content

Commit

Permalink
update DataGenerator & New Version 0.2.2
Browse files Browse the repository at this point in the history
  • Loading branch information
heatingma committed Nov 16, 2024
1 parent 67fadfd commit d8386f3
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 44 deletions.
2 changes: 1 addition & 1 deletion ml4co_kit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,5 @@
from .learning.utils import points_to_distmat, sparse_points


__version__ = "0.2.1"
__version__ = "0.2.2"
__author__ = "SJTU-ReThinkLab"
23 changes: 17 additions & 6 deletions ml4co_kit/generator/atsp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pathlib
from tqdm import tqdm
from typing import Union
from multiprocessing import Pool
from ml4co_kit.utils.type_utils import SOLVER_TYPE
from ml4co_kit.solver import ATSPSolver, ATSPLKHSolver

Expand All @@ -18,6 +17,7 @@
class ATSPDataGenerator:
def __init__(
self,
only_instance_for_us: bool = False,
num_threads: int = 1,
nodes_num: int = 55,
data_type: str = "sat",
Expand Down Expand Up @@ -70,12 +70,17 @@ def __init__(
if self.data_type == "sat":
self.nodes_num = 2 * sat_clauses_nums * sat_vars_nums + sat_clauses_nums

# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()
# only instance for us
self.only_instance_for_us = only_instance_for_us
self.check_data_type()
self.check_solver()
self.get_filename()

# generate and solve
if only_instance_for_us == False:
# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()
self.check_solver()
self.get_filename()

def check_num_threads(self):
self.samples_num = 0
Expand Down Expand Up @@ -177,6 +182,12 @@ def get_filename(self):
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)

def generate_only_instance_for_us(self, samples: int) -> np.ndarray:
    """Generate ATSP distance data only and return it via the solver.

    ``self.num_threads`` is set to ``samples`` while ``self.generate_func()``
    runs and is put back afterwards, so calling this method no longer
    changes the generator's configuration as a side effect.

    Args:
        samples: Value used for ``self.num_threads`` during generation.
            Presumably the number of instances ``generate_func`` produces
            (``generate_func`` is not visible here - TODO confirm).

    Returns:
        ``self.solver.dists`` after ``self.solver.from_data(dists=...)``
        has been called with the first element of the generated result.
    """
    # NOTE(review): ``self.solver`` must expose ``from_data`` and ``dists``;
    # confirm this holds when the generator skips ``check_solver()``.
    # Fall back to ``samples`` if the attribute was never set, which keeps
    # the original post-call state in that case.
    previous_num_threads = getattr(self, "num_threads", samples)
    self.num_threads = samples
    try:
        dists = self.generate_func()[0]
    finally:
        # Restore even when generation fails so later calls are unaffected.
        self.num_threads = previous_num_threads
    self.solver.from_data(dists=dists)
    return self.solver.dists

def generate(self):
start_time = time.time()
for _ in tqdm(
Expand Down
38 changes: 31 additions & 7 deletions ml4co_kit/generator/cvrp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import numpy as np
import pathlib
from tqdm import tqdm
from typing import Union
from multiprocessing import Pool
from typing import Union, Sequence
from ml4co_kit.utils.type_utils import SOLVER_TYPE
from ml4co_kit.solver import (
CVRPSolver, CVRPPyVRPSolver, CVRPLKHSolver, CVRPHGSSolver
Expand All @@ -16,6 +15,7 @@
class CVRPDataGenerator:
def __init__(
self,
only_instance_for_us: bool = False,
num_threads: int = 1,
nodes_num: int = 50,
data_type: str = "uniform",
Expand Down Expand Up @@ -73,21 +73,29 @@ def __init__(
self.test_samples_num = test_samples_num
self.save_path = save_path
self.filename = filename

# special for demand and capacity
self.min_demand = min_demand
self.max_demand = max_demand
self.min_capacity = min_capacity
self.max_capacity = max_capacity

# special for gaussian
self.gaussian_mean_x = gaussian_mean_x
self.gaussian_mean_y = gaussian_mean_y
self.gaussian_std = gaussian_std
# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()

# only instance for us
self.only_instance_for_us = only_instance_for_us
self.check_data_type()
self.check_solver()
self.get_filename()

# generate and solve
if only_instance_for_us == False:
# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()
self.check_solver()
self.get_filename()

def check_num_threads(self):
self.samples_num = 0
Expand Down Expand Up @@ -195,6 +203,22 @@ def get_filename(self):
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)

def generate_only_instance_for_us(self, samples: int) -> Sequence[np.ndarray]:
    """Generate CVRP instance data only and return it via the solver.

    ``self.num_threads`` is set to ``samples`` while the depot/node
    coordinates, demands and capacities are generated, and is put back
    afterwards so the generator's configuration is not changed as a side
    effect of this call.

    Args:
        samples: Value used for ``self.num_threads`` during generation.
            Presumably the batch size used by ``generate_func``,
            ``generate_demands`` and ``generate_capacities`` (none of
            them is visible here - TODO confirm).

    Returns:
        A tuple ``(depots, points, demands, capacities)`` read back from
        ``self.solver`` after ``self.solver.from_data(...)``.
    """
    # NOTE(review): ``self.solver`` must expose ``from_data`` plus the four
    # attributes read below; confirm this holds when ``check_solver()`` is
    # skipped by the constructor.
    # Fall back to ``samples`` if the attribute was never set, which keeps
    # the original post-call state in that case.
    previous_num_threads = getattr(self, "num_threads", samples)
    self.num_threads = samples
    try:
        batch_depots_coord, batch_nodes_coord = self.generate_func()
        batch_demands = self.generate_demands()
        batch_capacities = self.generate_capacities()
    finally:
        # Restore even when generation fails so later calls are unaffected.
        self.num_threads = previous_num_threads
    self.solver.from_data(
        depots=batch_depots_coord,
        points=batch_nodes_coord,
        demands=batch_demands,
        capacities=batch_capacities,
    )
    return (
        self.solver.depots,
        self.solver.points,
        self.solver.demands,
        self.solver.capacities,
    )

def generate(self):
start_time = time.time()
for _ in tqdm(
Expand Down
25 changes: 18 additions & 7 deletions ml4co_kit/generator/mcl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import networkx as nx
from tqdm import tqdm
from typing import Union
from typing import Union, List
from ml4co_kit.utils.graph.mcl import MClGraphData
from ml4co_kit.utils.type_utils import SOLVER_TYPE
from ml4co_kit.solver import MClSolver, MClGurobiSolver
Expand All @@ -15,6 +15,7 @@
class MClDataGenerator:
def __init__(
self,
only_instance_for_us: bool = False,
num_threads: int = 1,
nodes_num_min: int = 700,
nodes_num_max: int = 800,
Expand Down Expand Up @@ -92,13 +93,18 @@ def __init__(
self.ws_prob = ws_prob
self.ws_ring_neighbors = ws_ring_neighbors

# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()
# only instance for us
self.only_instance_for_us = only_instance_for_us
self.check_data_type()
self.check_solver()
self.check_save_path()
self.get_filename()

# generate and solve
if only_instance_for_us == False:
# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()
self.check_solver()
self.check_save_path()
self.get_filename()

def check_num_threads(self):
self.samples_num = 0
Expand Down Expand Up @@ -187,6 +193,11 @@ def check_free(self):
def random_weight(self, n, mu=1, sigma=0.1):
    """Draw ``n`` non-negative integer weights from a rounded normal.

    Values are sampled with ``np.random.normal(mu, sigma, n)``, rounded
    with ``np.around``, cast to ``int`` and clipped from below at 0.
    """
    drawn = np.random.normal(mu, sigma, n)
    as_int = np.around(drawn).astype(int)
    # Negative draws are clamped so that no weight is below zero.
    return as_int.clip(min=0)

def generate_only_instance_for_us(self, samples: int) -> List[MClGraphData]:
    """Generate ``samples`` graphs and return them as solver graph data.

    ``self.generate_func()`` is called once per requested sample; the
    collected results are passed to ``self.solver.from_nx_graph`` and
    ``self.solver.graph_data`` is returned.
    """
    generated = []
    for _ in range(samples):
        generated.append(self.generate_func())
    # The solver converts the generated graphs; read the result back from it.
    self.solver.from_nx_graph(nx_graphs=generated)
    return self.solver.graph_data

def generate(self):
start_time = time.time()
for _ in tqdm(
Expand Down
1 change: 0 additions & 1 deletion ml4co_kit/generator/mcut_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import time
import pickle
import pathlib
import random
import numpy as np
Expand Down
27 changes: 19 additions & 8 deletions ml4co_kit/generator/mis_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import networkx as nx
from tqdm import tqdm
from typing import Union
from typing import Union, List
from ml4co_kit.utils.graph.mis import MISGraphData
from ml4co_kit.utils.type_utils import SOLVER_TYPE
from ml4co_kit.solver import MISSolver, KaMISSolver, MISGurobiSolver
Expand All @@ -15,6 +15,7 @@
class MISDataGenerator:
def __init__(
self,
only_instance_for_us: bool = False,
num_threads: int = 1,
nodes_num_min: int = 700,
nodes_num_max: int = 800,
Expand Down Expand Up @@ -92,14 +93,19 @@ def __init__(
self.ws_prob = ws_prob
self.ws_ring_neighbors = ws_ring_neighbors

# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()
# only instance for us
self.only_instance_for_us = only_instance_for_us
self.check_data_type()
self.check_solver()
self.check_save_path()
self.get_filename()


# generate and solve
if only_instance_for_us == False:
# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()
self.check_solver()
self.check_save_path()
self.get_filename()

def check_num_threads(self):
self.samples_num = 0
for sample_type in self.sample_types:
Expand Down Expand Up @@ -194,6 +200,11 @@ def check_free(self):
def random_weight(self, n, mu=1, sigma=0.1):
    """Return ``n`` integer weights sampled around ``mu``, never below 0.

    Samples come from ``np.random.normal(mu, sigma, n)``; they are rounded
    with ``np.around``, converted to ``int`` and clipped at a minimum of 0.
    """
    raw_values = np.random.normal(mu, sigma, n)
    rounded_values = np.around(raw_values)
    weights = rounded_values.astype(int)
    return weights.clip(min=0)

def generate_only_instance_for_us(self, samples: int) -> List[MISGraphData]:
    """Produce ``samples`` graphs and hand them back as solver graph data.

    Each graph comes from one call to ``self.generate_func()``. The list
    is given to ``self.solver.from_nx_graph`` and the method returns
    ``self.solver.graph_data``.
    """
    graph_batch = []
    remaining = samples
    for _ in range(remaining):
        graph_batch.append(self.generate_func())
    self.solver.from_nx_graph(nx_graphs=graph_batch)
    return self.solver.graph_data

def generate(self):
if self.solver_type == SOLVER_TYPE.KAMIS:
# check
Expand Down
26 changes: 18 additions & 8 deletions ml4co_kit/generator/mvc_data.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
import time
import pickle
import pathlib
import random
import numpy as np
import networkx as nx
from tqdm import tqdm
from typing import Union
from typing import Union, List
from ml4co_kit.utils.graph.mvc import MVCGraphData
from ml4co_kit.utils.type_utils import SOLVER_TYPE
from ml4co_kit.solver import MVCSolver, MVCGurobiSolver
Expand All @@ -15,6 +14,7 @@
class MVCDataGenerator:
def __init__(
self,
only_instance_for_us: bool = False,
num_threads: int = 1,
nodes_num_min: int = 700,
nodes_num_max: int = 800,
Expand Down Expand Up @@ -92,13 +92,18 @@ def __init__(
self.ws_prob = ws_prob
self.ws_ring_neighbors = ws_ring_neighbors

# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()
# only instance for us
self.only_instance_for_us = only_instance_for_us
self.check_data_type()
self.check_solver()
self.check_save_path()
self.get_filename()

# generate and solve
if only_instance_for_us == False:
# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()
self.check_solver()
self.check_save_path()
self.get_filename()

def check_num_threads(self):
self.samples_num = 0
Expand Down Expand Up @@ -187,6 +192,11 @@ def check_free(self):
def random_weight(self, n, mu=1, sigma=0.1):
    """Sample ``n`` rounded, non-negative integer weights.

    Uses ``np.random.normal(mu, sigma, n)`` as the source, rounds with
    ``np.around``, casts to ``int`` and clips the result to be at least 0.
    """
    normal_draws = np.random.normal(mu, sigma, n)
    integer_draws = np.around(normal_draws).astype(int)
    non_negative = integer_draws.clip(min=0)
    return non_negative

def generate_only_instance_for_us(self, samples: int) -> List[MVCGraphData]:
    """Create ``samples`` graphs and return the solver's graph data.

    ``self.generate_func()`` is invoked ``samples`` times, the results are
    forwarded to ``self.solver.from_nx_graph`` and the value of
    ``self.solver.graph_data`` is returned.
    """
    collected_graphs = []
    for _ in range(samples):
        new_graph = self.generate_func()
        collected_graphs.append(new_graph)
    self.solver.from_nx_graph(nx_graphs=collected_graphs)
    return self.solver.graph_data

def generate(self):
start_time = time.time()
for _ in tqdm(
Expand Down
24 changes: 18 additions & 6 deletions ml4co_kit/generator/tsp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
class TSPDataGenerator:
def __init__(
self,
only_instance_for_us: bool = False,
num_threads: int = 1,
nodes_num: int = 50,
data_type: str = "uniform",
Expand Down Expand Up @@ -96,13 +97,18 @@ def __init__(
self.regret_save_path = regret_save_path
self.regret_solver = regret_solver

# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()
# only instance for us
self.only_instance_for_us = only_instance_for_us
self.check_data_type()
self.check_solver()
self.get_filename()
self.check_regret()

# generate and solve
if only_instance_for_us == False:
# check the input variables
self.sample_types = ["train", "val", "test"]
self.check_num_threads()
self.check_solver()
self.get_filename()
self.check_regret()

def check_num_threads(self):
self.samples_num = 0
Expand Down Expand Up @@ -239,6 +245,12 @@ def check_regret(self):
if not os.path.exists(self.regret_save_path):
os.makedirs(self.regret_save_path)

def generate_only_instance_for_us(self, samples: int) -> np.ndarray:
    """Generate TSP point data only and return it via the solver.

    ``self.num_threads`` is set to ``samples`` while ``self.generate_func()``
    runs and is put back afterwards, so calling this method no longer
    changes the generator's configuration as a side effect.

    Args:
        samples: Value used for ``self.num_threads`` during generation.
            Presumably the number of instances ``generate_func`` produces
            (``generate_func`` is not visible here - TODO confirm).

    Returns:
        ``self.solver.points`` after ``self.solver.from_data(points=...)``
        has been called with the generated data.
    """
    # NOTE(review): ``self.solver`` must expose ``from_data`` and ``points``;
    # confirm this holds when the generator skips ``check_solver()``.
    # Fall back to ``samples`` if the attribute was never set, which keeps
    # the original post-call state in that case.
    previous_num_threads = getattr(self, "num_threads", samples)
    self.num_threads = samples
    try:
        points = self.generate_func()
    finally:
        # Restore even when generation fails so later calls are unaffected.
        self.num_threads = previous_num_threads
    self.solver.from_data(points=points)
    return self.solver.points

def generate(self):
start_time = time.time()
cnt = 0
Expand Down

0 comments on commit d8386f3

Please sign in to comment.