-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: simulation.py
259 lines (229 loc) · 8.5 KB
/
simulation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import tskit
import numpy as np
import scipy.sparse as sparse
import numba
from numba import i4, f8
from numba.experimental import jitclass
import jax.numpy as jnp
dtype = jnp.bfloat16
# Numba jitclass attribute specification for TraitVector below.
# f8 = float64, i4 = int32 (numba scalar type shorthands); `[:]` = 1-D array.
spec = [
    ('sample_weights', f8[:]),        # NOTE(review): declared but never assigned in __init__ — confirm still needed
    ('parent', i4[:]),                # parent[u] = parent node of u in the current tree (-1 if none)
    ('num_samples', i4[:]),           # num_samples[u] = number of samples at/below u (set only for samples here)
    ('edges_left', f8[:]),            # per-edge left genomic coordinate
    ('edges_right', f8[:]),           # per-edge right genomic coordinate
    ('edges_parent', i4[:]),          # per-edge parent node id
    ('edges_child', i4[:]),           # per-edge child node id
    ('edge_insertion_order', i4[:]),  # order in which edges enter trees, scanning left to right
    ('edge_removal_order', i4[:]),    # order in which edges leave trees, scanning left to right
    ('sequence_length', f8),          # total genome length
    ('nodes_time', f8[:]),            # node times (used for branch lengths)
    ('samples', i4[:]),               # sample node ids
    ('position', f8),                 # current genomic position of the edge-diff sweep
    ('virtual_root', i4),             # index of the virtual root (== num_nodes)
    ('x', f8[:]),                     # x[u] = position at which u's pending branch segment started
    ('w', f8[:]),                     # NOTE(review): declared but never assigned in __init__ — confirm still needed
    ('stack', f8[:]),                 # accumulated trait contributions per node
    ('NULL', i4)                      # sentinel (-1), stands in for tskit.NULL inside numba
]
@jitclass(spec)
class TraitVector:
    """Simulate an additive trait along a tree sequence via an edge-diff sweep.

    Each branch contributes an independent Gaussian increment with variance
    proportional to (branch time-length * genomic span), i.e. a Brownian-motion
    style trait model.  Contributions are accumulated down to one "virtual
    sample" per real sample; `run()` returns the per-sample trait values.
    """

    def __init__(
        self,
        num_nodes,
        samples,
        nodes_time,
        edges_left,
        edges_right,
        edges_parent,
        edges_child,
        edge_insertion_order,
        edge_removal_order,
        sequence_length
    ):
        # Virtual root sits at index num_nodes; one virtual sample per real
        # sample is appended beyond it, so accumulators for samples survive
        # the final edge removals.
        N = num_nodes + 1 + len(samples)
        # Parent pointers for the current (quintuply linked) tree.
        self.parent = np.full(N, -1, dtype=np.int32)
        # Sample counts (only samples / virtual samples are marked here).
        self.num_samples = np.full(N, 0, dtype=np.int32)
        # Edge table columns and the left-to-right sweep indexes.
        self.edges_left = edges_left
        self.edges_right = edges_right
        self.edges_parent = edges_parent
        self.edges_child = edges_child
        self.edge_insertion_order = edge_insertion_order
        self.edge_removal_order = edge_removal_order
        self.sequence_length = sequence_length
        self.nodes_time = nodes_time
        self.samples = samples
        self.position = 0
        self.virtual_root = num_nodes
        # x[u]: genomic position where u's current branch segment began.
        self.x = np.zeros(N, dtype=np.float64)
        # stack[u]: trait contribution accumulated at u so far.
        self.stack = np.zeros(N, dtype=np.float64)
        self.NULL = -1  # to avoid tskit.NULL in numba
        for j, u in enumerate(samples):
            self.num_samples[u] = 1
            # Add a branch from each sample u down to its virtual sample v.
            v = num_nodes + 1 + j
            self.parent[v] = u
            self.num_samples[v] = 1

    def remove_edge(self, p, c):
        """Detach edge (p, c), flushing c's pending branch contribution."""
        self.stack[c] += self.get_z(c)
        self.x[c] = self.position
        self.parent[c] = -1
        self.adjust_path_up(c, p, -1)

    def insert_edge(self, p, c):
        """Attach edge (p, c), starting a fresh branch segment for c."""
        self.adjust_path_up(c, p, +1)
        self.x[c] = self.position
        self.parent[c] = p

    def adjust_path_up(self, c, p, sign):
        """Propagate accumulated contributions along the path above p.

        sign = -1 when removing an edge, +1 when adding one: the ancestors'
        accumulated values are subtracted from / no longer added to c's total.
        """
        while p != self.NULL:
            # Flush p's own pending branch segment before reading its stack.
            self.stack[p] += self.get_z(p)
            self.x[p] = self.position
            self.stack[c] -= sign * self.stack[p]
            p = self.parent[p]

    def get_z(self, u):
        """Gaussian increment for u's branch over its pending genomic span.

        Variance is (parent time - child time) * span; virtual samples and
        detached nodes contribute nothing.
        """
        p = self.parent[u]
        if p == self.NULL or u >= self.virtual_root:
            return 0.0
        time = self.nodes_time[p] - self.nodes_time[u]
        span = self.position - self.x[u]
        return np.sqrt(time * span) * np.random.normal()

    def run(self):
        """Sweep all trees left to right; return per-sample trait values.

        Returns a float64 array aligned with `self.samples`.
        """
        sequence_length = self.sequence_length
        M = self.edges_left.shape[0]
        in_order = self.edge_insertion_order
        out_order = self.edge_removal_order
        edges_left = self.edges_left
        edges_right = self.edges_right
        edges_parent = self.edges_parent
        edges_child = self.edges_child
        j = 0
        k = 0
        # TODO: self.position is redundant with left
        left = 0
        self.position = left
        while k < M and left <= self.sequence_length:
            # Remove edges ending at this breakpoint.
            while k < M and edges_right[out_order[k]] == left:
                p = edges_parent[out_order[k]]
                c = edges_child[out_order[k]]
                self.remove_edge(p, c)
                k += 1
            # Insert edges starting at this breakpoint.
            while j < M and edges_left[in_order[j]] == left:
                p = edges_parent[in_order[j]]
                c = edges_child[in_order[j]]
                self.insert_edge(p, c)
                j += 1
            # Advance to the next breakpoint.
            right = sequence_length
            if j < M:
                right = min(right, edges_left[in_order[j]])
            if k < M:
                right = min(right, edges_right[out_order[k]])
            left = right
            self.position = left
        # Flush remaining contributions down to the virtual samples.
        for j, u in enumerate(self.samples):
            v = self.virtual_root + 1 + j
            self.remove_edge(u, v)
        out = np.zeros(len(self.samples))
        for out_i in range(len(self.samples)):
            i = out_i + self.virtual_root + 1
            out[out_i] = self.stack[i]
        return out
def genetic_value(ts, ij, **kwargs):
    """Simulate per-individual genetic values on a tree sequence.

    Runs a TraitVector sweep over ``ts`` to obtain one trait value per sample
    node, then sums the values of each individual's genomes.

    Parameters
    ----------
    ts : tskit.TreeSequence
        Tree sequence to simulate on.
    ij : np.ndarray
        (sample node, individual index) pairs.  Currently unused; retained
        for interface compatibility with existing callers.
    **kwargs
        Extra keyword arguments forwarded to the TraitVector constructor.

    Returns
    -------
    np.ndarray
        One genetic value per individual.
    """
    rv = TraitVector(
        ts.num_nodes,
        samples=ts.samples(),
        nodes_time=ts.nodes_time,
        edges_left=ts.edges_left,
        edges_right=ts.edges_right,
        edges_parent=ts.edges_parent,
        edges_child=ts.edges_child,
        edge_insertion_order=ts.indexes_edge_insertion_order,
        edge_removal_order=ts.indexes_edge_removal_order,
        sequence_length=ts.sequence_length,
        **kwargs,
    )
    g_nodes = rv.run()
    # Assumes all individuals share the ploidy of individual 0 and that
    # sample nodes are grouped contiguously by individual — TODO confirm.
    ploidy = ts.individual(0).nodes.size
    g_individuals = g_nodes.reshape(-1, ploidy).sum(axis=1)
    return g_individuals
class DataLoader():
    """Infinite iterator yielding simulated trait batches as JAX arrays.

    Each ``next()`` call simulates ``batch_size`` trait vectors on the tree
    sequence (with freshly drawn tau/sigma parameters), normalizes them, and
    builds random node/edge padding masks for a graph defined by
    ``receivers``/``senders``.  Iteration never raises StopIteration.
    """

    def __init__(self,
                 ts: tskit.TreeSequence,
                 norm_factor: float,
                 receivers: np.ndarray,
                 senders: np.ndarray,
                 tau_range: tuple = (0.01, 1),
                 sigma_range: tuple = (0.1, 1),
                 batch_size: int = 200,
                 missing_range: tuple = (0.5, 1),
                 ) -> None:
        # NOTE: defaults are tuples (not lists) to avoid shared mutable
        # default arguments; callers may still pass lists.
        self.ts = ts
        self.norm_factor = norm_factor
        self.receivers = receivers
        self.senders = senders
        self.tau_range = tau_range
        self.sigma_range = sigma_range
        self.batch_size = batch_size
        self.missing_range = missing_range
        # Graph nodes are individuals, not tree-sequence nodes.
        self.num_nodes = ts.num_individuals
        self.num_edges = receivers.size
        individuals = [i.id for i in ts.individuals()]
        # (sample node, individual index) pairs, passed to genetic_value.
        self.ij = np.vstack([
            [n, k]
            for k, i in enumerate(individuals)
            for n in ts.individual(i).nodes
        ])
        # Assumes uniform ploidy across individuals — TODO confirm.
        ploidy = ts.individual(0).nodes.size
        self.ploidy = ploidy

    def __iter__(self):
        return self

    def __next__(self):
        """Return (traits, params, factors, nodes_padding, edges_padding)."""
        traits = np.empty((self.batch_size, self.ts.num_individuals, 1))
        params = np.empty((self.batch_size, 2))
        factors = np.empty(self.batch_size)
        nodes_padding = np.ones((self.batch_size, self.num_nodes))
        edges_padding = np.ones((self.batch_size, self.num_edges))
        for i in range(self.batch_size):
            # Sample model parameters.
            tau = np.random.uniform(*self.tau_range)
            sigma = np.random.uniform(*self.sigma_range)
            # Sample trait: scaled genetic component plus Gaussian noise.
            g = genetic_value(self.ts, self.ij) / np.sqrt(self.norm_factor) * tau
            e = np.random.normal(size=self.ts.num_individuals) * sigma
            y = g + e
            # Normalize; the factor is returned so it can be undone later.
            factor = 1.5 * y.std()
            y /= factor
            # Construct graph features; params are rescaled consistently.
            traits[i] = y[:, None]
            params[i] = np.asarray([tau, sigma]) / factor
            factors[i] = factor
            # Paddings: probability of keeping each node, clipped to <= 1.
            p_keep = np.random.uniform(*self.missing_range)
            p_keep = np.clip(p_keep, a_min=None, a_max=1)
            # Sample nodes to keep.
            nodes_keep = np.random.binomial(1, p_keep, size=self.num_nodes)
            nodes_padding[i] = nodes_keep
            # Keep an edge only when both endpoints are kept.
            nodes_keep_idx = np.arange(self.num_nodes)[nodes_keep.astype(bool)]
            receivers_keep = np.isin(self.receivers, nodes_keep_idx).astype(int)
            senders_keep = np.isin(self.senders, nodes_keep_idx).astype(int)
            edges_padding[i] = receivers_keep * senders_keep
        # Cast everything to the module-level JAX dtype (bfloat16).
        traits, params, factors, nodes_padding, edges_padding = (
            jnp.array(traits, dtype=dtype),
            jnp.array(params, dtype=dtype),
            jnp.array(factors, dtype=dtype),
            jnp.array(nodes_padding, dtype=dtype),
            jnp.array(edges_padding, dtype=dtype)
        )
        return traits, params, factors, nodes_padding, edges_padding