3
3
4
4
import numpy as np
5
5
import ot .backend
6
+ from ot .lp import emd
6
7
import pandas as pd
7
8
import tempfile
8
9
9
10
from paste import pairwise_align , center_align
10
- from paste .PASTE import center_ot , intersect , center_NMF
11
-
11
+ from paste .PASTE import (
12
+ center_ot ,
13
+ intersect ,
14
+ center_NMF ,
15
+ extract_data_matrix ,
16
+ kl_divergence_backend ,
17
+ to_dense_array ,
18
+ my_fused_gromov_wasserstein ,
19
+ solve_gromov_linesearch ,
20
+ )
21
+ from pandas .testing import assert_frame_equal
12
22
13
23
test_dir = Path (__file__ ).parent
14
24
input_dir = test_dir / "data/input"
@@ -61,17 +71,24 @@ def test_center_alignment(slices):
61
71
dissimilarity = "kl" ,
62
72
distributions = [slices [i ].obsm ["weights" ] for i in range (len (slices ))],
63
73
)
64
- pd .DataFrame (center_slice .uns ["paste_W" ], index = center_slice .obs .index ).to_csv (
65
- temp_dir / "W_center.csv"
74
+ assert_frame_equal (
75
+ pd .DataFrame (
76
+ center_slice .uns ["paste_W" ],
77
+ index = center_slice .obs .index ,
78
+ columns = [str (i ) for i in range (15 )],
79
+ ),
80
+ pd .read_csv (output_dir / "W_center.csv" , index_col = 0 ),
81
+ check_names = False ,
82
+ rtol = 1e-05 ,
83
+ atol = 1e-08 ,
66
84
)
67
- pd .DataFrame (center_slice .uns ["paste_H" ], columns = center_slice .var .index ).to_csv (
68
- temp_dir / "H_center.csv"
85
+ assert_frame_equal (
86
+ pd .DataFrame (center_slice .uns ["paste_H" ], columns = center_slice .var .index ),
87
+ pd .read_csv (output_dir / "H_center.csv" ),
88
+ rtol = 1e-05 ,
89
+ atol = 1e-08 ,
69
90
)
70
91
71
- # TODO: The following computations seem to be architecture dependent (need to look into as for how)
72
- # assert_checksum_equals(temp_dir, "W_center.csv")
73
- # assert_checksum_equals(temp_dir, "H_center.csv")
74
-
75
92
for i , pi in enumerate (pairwise_info ):
76
93
pd .DataFrame (
77
94
pi , index = center_slice .obs .index , columns = slices [i ].obs .index
@@ -118,7 +135,6 @@ def test_center_ot(slices):
118
135
119
136
120
137
def test_center_NMF (intersecting_slices ):
121
- temp_dir = Path (tempfile .mkdtemp ())
122
138
n_slices = len (intersecting_slices )
123
139
124
140
pairwise_info = [
@@ -136,8 +152,106 @@ def test_center_NMF(intersecting_slices):
136
152
random_seed = 0 ,
137
153
)
138
154
139
- pd .DataFrame (_W ).to_csv (temp_dir / "W_center_NMF.csv" )
140
- pd .DataFrame (_H ).to_csv (temp_dir / "H_center_NMF.csv" )
141
- # TODO: The following computations seem to be architecture dependent (need to look into as for how)
142
- # assert_checksum_equals(temp_dir, "W_center_NMF.csv")
143
- # assert_checksum_equals(temp_dir, "H_center_NMF.csv")
155
+ assert_frame_equal (
156
+ pd .DataFrame (
157
+ _W ,
158
+ index = intersecting_slices [0 ].obs .index ,
159
+ columns = [str (i ) for i in range (15 )],
160
+ ),
161
+ pd .read_csv (output_dir / "W_center_NMF.csv" , index_col = 0 ),
162
+ rtol = 1e-05 ,
163
+ atol = 1e-08 ,
164
+ )
165
+ assert_frame_equal (
166
+ pd .DataFrame (_H , columns = intersecting_slices [0 ].var .index ),
167
+ pd .read_csv (output_dir / "H_center_NMF.csv" ),
168
+ rtol = 1e-05 ,
169
+ atol = 1e-08 ,
170
+ )
171
+
172
+
173
def test_fused_gromov_wasserstein(slices):
    """Smoke-test ``my_fused_gromov_wasserstein`` on the first two slices.

    Both slices are restricted to their shared genes.  Euclidean distance
    matrices are built from each slice's ``obsm["spatial"]`` coordinates,
    each slice gets a uniform distribution over its observations, and the
    cost matrix ``M`` comes from ``kl_divergence_backend`` applied to the
    dense data matrices offset by 0.01.

    The resulting coupling is only written to a scratch CSV; no numerical
    assertion is made yet (see the TODO at the end).
    """
    common_genes = intersect(slices[0].var.index, slices[1].var.index)
    sliceA = slices[0][:, common_genes]
    sliceB = slices[1][:, common_genes]

    nx = ot.backend.NumpyBackend()
    slice1_dist = ot.dist(
        nx.from_numpy(sliceA.obsm["spatial"]),
        nx.from_numpy(sliceA.obsm["spatial"]),
        metric="euclidean",
    )
    slice2_dist = ot.dist(
        nx.from_numpy(sliceB.obsm["spatial"]),
        nx.from_numpy(sliceB.obsm["spatial"]),
        metric="euclidean",
    )
    # Uniform weights: every observation of a slice gets 1 / n_observations.
    slice1_distr = nx.ones((sliceA.shape[0],)) / sliceA.shape[0]
    slice2_distr = nx.ones((sliceB.shape[0],)) / sliceB.shape[0]

    slice1_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceA, None)))
    slice2_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceB, None)))

    M = nx.from_numpy(kl_divergence_backend(slice1_X + 0.01, slice2_X + 0.01))

    # The log is requested (log=True) but not inspected by this test.
    pairwise_info, _log = my_fused_gromov_wasserstein(
        M,
        slice1_dist,
        slice2_dist,
        slice1_distr,
        slice2_distr,
        G_init=None,
        loss_fun="square_loss",
        alpha=0.1,
        log=True,
        numItermax=200,
    )

    # Use a self-cleaning scratch directory so repeated test runs do not
    # leave orphaned temp directories behind (mkdtemp never removes them).
    with tempfile.TemporaryDirectory() as scratch:
        temp_dir = Path(scratch)
        pd.DataFrame(pairwise_info).to_csv(temp_dir / "fused_gromov_wasserstein.csv")
        # TODO: Need to figure out where the randomness is coming from
        # assert_checksum_equals(temp_dir, "fused_gromov_wasserstein.csv")
214
+
215
+
216
def test_gromov_linesearch(slices):
    """Check ``solve_gromov_linesearch`` against known values for slices 1 and 2.

    Both slices are restricted to their shared genes.  Euclidean distance
    matrices are built from each slice's ``obsm["spatial"]`` coordinates,
    each slice gets a uniform distribution over its observations, and the
    cost matrix ``M`` comes from ``kl_divergence_backend`` applied to the
    dense data matrices offset by 0.01.  One descent direction is derived
    from the product coupling via ``emd`` and passed to the line search,
    whose step size, function-call count and cost are asserted.
    """
    common_genes = intersect(slices[1].var.index, slices[2].var.index)
    sliceA = slices[1][:, common_genes]
    sliceB = slices[2][:, common_genes]

    nx = ot.backend.NumpyBackend()
    slice1_dist = ot.dist(
        nx.from_numpy(sliceA.obsm["spatial"]),
        nx.from_numpy(sliceA.obsm["spatial"]),
        metric="euclidean",
    )
    slice2_dist = ot.dist(
        nx.from_numpy(sliceB.obsm["spatial"]),
        nx.from_numpy(sliceB.obsm["spatial"]),
        metric="euclidean",
    )
    # Uniform weights: every observation of a slice gets 1 / n_observations.
    slice1_distr = nx.ones((sliceA.shape[0],)) / sliceA.shape[0]
    slice2_distr = nx.ones((sliceB.shape[0],)) / sliceB.shape[0]

    slice1_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceA, None)))
    slice2_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceB, None)))

    M = nx.from_numpy(kl_divergence_backend(slice1_X + 0.01, slice2_X + 0.01))
    slice1_distr, slice2_distr = ot.utils.list_to_array(slice1_distr, slice2_distr)

    constC, hC1, hC2 = ot.gromov.init_matrix(
        slice1_dist, slice2_dist, slice1_distr, slice2_distr, loss_fun="square_loss"
    )

    # Start from the product (outer) coupling of the two distributions.
    G = slice1_distr[:, None] * slice2_distr[None, :]
    # NOTE(review): the cost below weights the GW term as "0.1 * gwloss",
    # whereas this line *adds* 0.1 to the gradient term.  Confirm that "+"
    # (rather than "*") is intended; the expected values asserted at the end
    # were produced with this exact expression, so it is left unchanged.
    Mi = M + 0.1 + ot.gromov.gwggrad(constC, hC1, hC2, G)
    Mi = Mi + nx.min(Mi)

    Gc = emd(slice1_distr, slice2_distr, Mi)
    deltaG = Gc - G
    costG = nx.sum(M * G) + 0.1 * ot.gromov.gwloss(constC, hC1, hC2, G)
    alpha, fc, cost_G = solve_gromov_linesearch(
        G, deltaG, costG, slice1_dist, slice2_dist, M=0.0, reg=1.0, nx=nx
    )
    assert alpha == 1.0
    assert fc == 1
    # Compare with an explicit absolute tolerance (same 6-decimal precision
    # as before) instead of exact equality on a rounded float.
    assert np.isclose(cost_G, -11.419226, rtol=0.0, atol=5e-7)
0 commit comments