-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsimulator.py
94 lines (76 loc) · 2.85 KB
/
simulator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from tqdm import tqdm
from graph_helpers import observe_uninfected_node
from random_steiner_tree.util import isolate_vertex
from experiment import gen_input
from query_selection import NoMoreQuery
class Simulator:
    """Drive a query generator through a sequence of node queries.

    At each step the query generator selects a node; the node's ground-truth
    label is read from ``c`` and fed back to the generator.
    """

    def __init__(self, g, query_generator, gi=None, print_log=False):
        """
        g: graph_tool.Graph or graph_tool.GraphView
        query_generator: object providing ``receive_observation``,
            ``select_query``, ``update_pool`` and ``update_observation``
            (the methods ``run`` calls on it)
        gi: random_steiner_tree.Graph, optional; if given, ``isolate_vertex``
            is applied to it for every queried node that is uninfected
        print_log: if True, show a tqdm progress bar and print progress
            messages
        """
        self.g = g
        self.gi = gi
        self.q_gen = query_generator
        self.print_log = print_log

    def run(self, n_queries, obs=None, c=None, gen_input_kwargs=None,
            iter_callback=None):
        """Issue up to ``n_queries`` queries; return the queried nodes.

        Parameters
        ----------
        n_queries: int
            maximum number of queries to issue.
        obs: iterable of nodes, optional
            initially observed infected nodes.
        c: indexable by node, optional
            per-node ground truth: -1 marks an uninfected node, a value >= 0
            an infected one. If either ``obs`` or ``c`` is None, both are
            regenerated via ``gen_input(self.g, **gen_input_kwargs)``.
        gen_input_kwargs: dict, optional
            keyword arguments forwarded to ``gen_input`` (default: none).
        iter_callback: callable, optional
            called as ``iter_callback(g, q_gen, inf_nodes, uninf_nodes)``
            after each fully processed query.

        Returns
        -------
        qs: list
            the queried nodes, in query order.
        aux: dict
            ``{'graph_changed': bool, 'obs': obs, 'c': c}``; ``graph_changed``
            is True once an uninfected node has been queried and
            ``observe_uninfected_node`` applied to the graph.
        """
        if gen_input_kwargs is None:
            gen_input_kwargs = {}
        if obs is None or c is None:
            obs, c = gen_input(self.g, **gen_input_kwargs)[:2]

        self.q_gen.receive_observation(obs, c)

        aux = {'graph_changed': False,
               'obs': obs,
               'c': c}

        qs = []
        inf_nodes = list(obs)  # copy: grows as infected nodes are queried
        uninf_nodes = []

        if self.print_log:
            iters = tqdm(range(n_queries), total=n_queries)
        else:
            iters = range(n_queries)

        for _ in iters:
            try:
                q = self.q_gen.select_query(self.g, inf_nodes)
            except NoMoreQuery:
                if self.print_log:
                    print('no more nodes to query. queried {} nodes'
                          .format(len(qs)))
                break

            qs.append(q)
            # NOTE(review): the final query is recorded but not processed
            # (no graph/sample update, no iter_callback) -- presumably to
            # skip work since no further query will be selected. Confirm
            # callers expect this.
            if len(qs) == n_queries:
                if self.print_log:
                    print('num. queries reached')
                break

            # One test decides both the branch and the label fed back to the
            # generator, so they cannot disagree (assumes c holds -1 for
            # uninfected nodes and non-negative values for infected ones).
            label = int(c[q] >= 0)
            if label == 0:
                observe_uninfected_node(self.g, q, inf_nodes)
                if self.gi is not None:
                    isolate_vertex(self.gi, q)
                self.q_gen.update_pool(self.g)
                aux['graph_changed'] = True
                uninf_nodes.append(q)
            else:
                inf_nodes.append(q)

            # update tree samples if necessary
            if self.print_log:
                print('update samples started')
            try:
                self.q_gen.update_observation(self.g, inf_nodes, q, label, c)
            except NoMoreQuery:
                if self.print_log:
                    print('no more queries')
                break
            if self.print_log:
                print('update samples done')

            if callable(iter_callback):
                iter_callback(self.g, self.q_gen, inf_nodes, uninf_nodes)

        return qs, aux