-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcore.py
executable file
·74 lines (67 loc) · 2.81 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import random
from graph_helpers import (extract_steiner_tree,
gen_random_spanning_tree,
filter_graph_by_edges,
reachable_node_set,
swap_end_points)
from tqdm import tqdm
from random_steiner_tree import random_steiner_tree
def sample_steiner_trees(g, obs,
                         method,
                         n_samples,
                         gi=None,
                         root=None,
                         root_sampler=None,
                         return_type='nodes',
                         log=False,
                         verbose=False):
    """Sample `n_samples` random Steiner trees that span the terminals `obs` in `g`.

    `g`: graph to sample from; used to pick default roots, by the
        'cut_naive' method, and to build the output when return_type == 'tree'
    `obs`: terminal nodes that every sampled tree must span
    `method`: one of 'cut', 'cut_naive', 'loop_erased'.
        'cut_naive' draws a random spanning tree of `g` rooted at the chosen
        root and extracts the Steiner tree from it; 'cut' and 'loop_erased'
        delegate to `random_steiner_tree(gi, obs, root, method)`
    `n_samples`: number of trees to draw
    `gi`: graph object passed to `random_steiner_tree`; required (asserted)
        when `method` is 'cut' or 'loop_erased'
    `root`: fixed root used for every sample; if None, a root is drawn per sample
    `root_sampler`: zero-argument callable returning a root, used only when
        `root` is None. If both are None, the root is drawn uniformly at random
        from the nodes reachable from the first element of `obs` (unreachable
        nodes, e.g. isolated ones, are ignored)
    `return_type`: representation of each sample for 'cut' / 'loop_erased':
        'nodes'  -> set of node ids appearing on the sampled edges
        'tuples' -> `swap_end_points(edges)`
        'tree'   -> `filter_graph_by_edges(g, edges)`
        For 'cut_naive' the value is forwarded unchanged to
        `extract_steiner_tree` as `return_nodes`.
    `log`: if True, wrap the sampling loop in a tqdm progress bar
    `verbose`: forwarded to `random_steiner_tree`

    Returns a list with one sampled tree per iteration.

    Raises AssertionError if `method` is unknown, if `gi` is None for
    'cut'/'loop_erased', or if `root_sampler` is not callable; raises
    ValueError for an unknown `return_type` with 'cut'/'loop_erased'.
    """
    assert method in {'cut', 'cut_naive', 'loop_erased'}

    if log:
        iters = tqdm(range(n_samples), total=n_samples)
    else:
        iters = range(n_samples)

    # The default root candidates depend only on `g` and `obs`, not on the
    # sample, so compute them once (lazily, so n_samples == 0 does no work)
    # instead of re-running the reachability search on every iteration.
    # NOTE(review): assumes the sampling helpers do not mutate `g` — confirm.
    default_roots = None

    steiner_tree_samples = []
    for _ in iters:
        # Pick the root: fixed root > custom sampler > uniform over reachable nodes.
        if root is not None:
            r = root
        elif root_sampler is not None:
            assert callable(root_sampler), 'root_sampler should be callable'
            r = root_sampler()
        else:
            if default_roots is None:
                default_roots = list(reachable_node_set(g, list(obs)[0]))
            r = int(random.choice(default_roots))

        if method == 'cut_naive':
            rand_t = gen_random_spanning_tree(g, root=r)
            # NOTE(review): `return_type` is a string, so non-empty values such
            # as 'tuples' / 'tree' are truthy when used as `return_nodes`; kept
            # as-is for compatibility — confirm the intended semantics.
            st = extract_steiner_tree(rand_t, obs, return_nodes=return_type)
        elif method in {'cut', 'loop_erased'}:
            assert gi is not None
            edges = random_steiner_tree(gi, obs, r, method, verbose=verbose)
            if return_type == 'nodes':
                st = {u for e in edges for u in e}
            elif return_type == 'tuples':
                st = swap_end_points(edges)
            elif return_type == 'tree':
                st = filter_graph_by_edges(g, edges)
            else:
                raise ValueError('unknown return_type {}'.format(return_type))
        steiner_tree_samples.append(st)
    return steiner_tree_samples