Skip to content

Commit d55e6f6

Browse files
authored
added more test functions and files, implementing assert frame equals from pandas
1 parent c81f2ad commit d55e6f6

File tree

6 files changed

+672
-304
lines changed

6 files changed

+672
-304
lines changed

src/paste/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .PASTE import pairwise_align, center_align, center_ot, center_NMF
2-
from .helper import match_spots_using_spatial_heuristic, filter_for_common_genes, apply_trsf, intersect
1+
from .PASTE import pairwise_align, center_align, center_ot, center_NMF, my_fused_gromov_wasserstein, solve_gromov_linesearch
2+
from .helper import match_spots_using_spatial_heuristic, filter_for_common_genes, apply_trsf, intersect,extract_data_matrix, to_dense_array, kl_divergence_backend
33
from .visualization import plot_slice, stack_slices_pairwise, stack_slices_center

tests/data/output/H_center.csv

Lines changed: 16 additions & 16 deletions
Large diffs are not rendered by default.

tests/data/output/H_center_NMF.csv

Lines changed: 16 additions & 16 deletions
Large diffs are not rendered by default.

tests/data/output/W_center_NMF.csv

Lines changed: 254 additions & 254 deletions
Large diffs are not rendered by default.

tests/data/output/fused_gromov_wasserstein.csv

Lines changed: 254 additions & 0 deletions
Large diffs are not rendered by default.

tests/test_paste.py

Lines changed: 130 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,22 @@
33

44
import numpy as np
55
import ot.backend
6+
from ot.lp import emd
67
import pandas as pd
78
import tempfile
89

910
from paste import pairwise_align, center_align
10-
from paste.PASTE import center_ot, intersect, center_NMF
11-
11+
from paste.PASTE import (
12+
center_ot,
13+
intersect,
14+
center_NMF,
15+
extract_data_matrix,
16+
kl_divergence_backend,
17+
to_dense_array,
18+
my_fused_gromov_wasserstein,
19+
solve_gromov_linesearch,
20+
)
21+
from pandas.testing import assert_frame_equal
1222

1323
test_dir = Path(__file__).parent
1424
input_dir = test_dir / "data/input"
@@ -61,17 +71,24 @@ def test_center_alignment(slices):
6171
dissimilarity="kl",
6272
distributions=[slices[i].obsm["weights"] for i in range(len(slices))],
6373
)
64-
pd.DataFrame(center_slice.uns["paste_W"], index=center_slice.obs.index).to_csv(
65-
temp_dir / "W_center.csv"
74+
assert_frame_equal(
75+
pd.DataFrame(
76+
center_slice.uns["paste_W"],
77+
index=center_slice.obs.index,
78+
columns=[str(i) for i in range(15)],
79+
),
80+
pd.read_csv(output_dir / "W_center.csv", index_col=0),
81+
check_names=False,
82+
rtol=1e-05,
83+
atol=1e-08,
6684
)
67-
pd.DataFrame(center_slice.uns["paste_H"], columns=center_slice.var.index).to_csv(
68-
temp_dir / "H_center.csv"
85+
assert_frame_equal(
86+
pd.DataFrame(center_slice.uns["paste_H"], columns=center_slice.var.index),
87+
pd.read_csv(output_dir / "H_center.csv"),
88+
rtol=1e-05,
89+
atol=1e-08,
6990
)
7091

71-
# TODO: The following computations seem to be architecture dependent (need to look into as for how)
72-
# assert_checksum_equals(temp_dir, "W_center.csv")
73-
# assert_checksum_equals(temp_dir, "H_center.csv")
74-
7592
for i, pi in enumerate(pairwise_info):
7693
pd.DataFrame(
7794
pi, index=center_slice.obs.index, columns=slices[i].obs.index
@@ -118,7 +135,6 @@ def test_center_ot(slices):
118135

119136

120137
def test_center_NMF(intersecting_slices):
121-
temp_dir = Path(tempfile.mkdtemp())
122138
n_slices = len(intersecting_slices)
123139

124140
pairwise_info = [
@@ -136,8 +152,106 @@ def test_center_NMF(intersecting_slices):
136152
random_seed=0,
137153
)
138154

139-
pd.DataFrame(_W).to_csv(temp_dir / "W_center_NMF.csv")
140-
pd.DataFrame(_H).to_csv(temp_dir / "H_center_NMF.csv")
141-
# TODO: The following computations seem to be architecture dependent (need to look into as for how)
142-
# assert_checksum_equals(temp_dir, "W_center_NMF.csv")
143-
# assert_checksum_equals(temp_dir, "H_center_NMF.csv")
155+
assert_frame_equal(
156+
pd.DataFrame(
157+
_W,
158+
index=intersecting_slices[0].obs.index,
159+
columns=[str(i) for i in range(15)],
160+
),
161+
pd.read_csv(output_dir / "W_center_NMF.csv", index_col=0),
162+
rtol=1e-05,
163+
atol=1e-08,
164+
)
165+
assert_frame_equal(
166+
pd.DataFrame(_H, columns=intersecting_slices[0].var.index),
167+
pd.read_csv(output_dir / "H_center_NMF.csv"),
168+
rtol=1e-05,
169+
atol=1e-08,
170+
)
171+
172+
173+
def test_fused_gromov_wasserstein(slices):
174+
temp_dir = Path(tempfile.mkdtemp())
175+
176+
common_genes = intersect(slices[0].var.index, slices[1].var.index)
177+
sliceA = slices[0][:, common_genes]
178+
sliceB = slices[1][:, common_genes]
179+
180+
nx = ot.backend.NumpyBackend()
181+
slice1_dist = ot.dist(
182+
nx.from_numpy(sliceA.obsm["spatial"]),
183+
nx.from_numpy(sliceA.obsm["spatial"]),
184+
metric="euclidean",
185+
)
186+
slice2_dist = ot.dist(
187+
nx.from_numpy(sliceB.obsm["spatial"]),
188+
nx.from_numpy(sliceB.obsm["spatial"]),
189+
metric="euclidean",
190+
)
191+
slice1_distr = nx.ones((sliceA.shape[0],)) / sliceA.shape[0]
192+
slice2_distr = nx.ones((sliceB.shape[0],)) / sliceB.shape[0]
193+
194+
slice1_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceA, None)))
195+
slice2_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceB, None)))
196+
197+
M = nx.from_numpy(kl_divergence_backend(slice1_X + 0.01, slice2_X + 0.01))
198+
199+
pairwise_info, log = my_fused_gromov_wasserstein(
200+
M,
201+
slice1_dist,
202+
slice2_dist,
203+
slice1_distr,
204+
slice2_distr,
205+
G_init=None,
206+
loss_fun="square_loss",
207+
alpha=0.1,
208+
log=True,
209+
numItermax=200,
210+
)
211+
pd.DataFrame(pairwise_info).to_csv(temp_dir / "fused_gromov_wasserstein.csv")
212+
# TODO: Need to figure out where the randomness is coming from
213+
# assert_checksum_equals(temp_dir, "fused_gromov_wasserstein.csv")
214+
215+
216+
def test_gromov_linesearch(slices):
217+
common_genes = intersect(slices[1].var.index, slices[2].var.index)
218+
sliceA = slices[1][:, common_genes]
219+
sliceB = slices[2][:, common_genes]
220+
221+
nx = ot.backend.NumpyBackend()
222+
slice1_dist = ot.dist(
223+
nx.from_numpy(sliceA.obsm["spatial"]),
224+
nx.from_numpy(sliceA.obsm["spatial"]),
225+
metric="euclidean",
226+
)
227+
slice2_dist = ot.dist(
228+
nx.from_numpy(sliceB.obsm["spatial"]),
229+
nx.from_numpy(sliceB.obsm["spatial"]),
230+
metric="euclidean",
231+
)
232+
slice1_distr = nx.ones((sliceA.shape[0],)) / sliceA.shape[0]
233+
slice2_distr = nx.ones((sliceB.shape[0],)) / sliceB.shape[0]
234+
235+
slice1_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceA, None)))
236+
slice2_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceB, None)))
237+
238+
M = nx.from_numpy(kl_divergence_backend(slice1_X + 0.01, slice2_X + 0.01))
239+
slice1_distr, slice2_distr = ot.utils.list_to_array(slice1_distr, slice2_distr)
240+
241+
constC, hC1, hC2 = ot.gromov.init_matrix(
242+
slice1_dist, slice2_dist, slice1_distr, slice2_distr, loss_fun="square_loss"
243+
)
244+
245+
G = slice1_distr[:, None] * slice2_distr[None, :]
246+
Mi = M + 0.1 + ot.gromov.gwggrad(constC, hC1, hC2, G)
247+
Mi = Mi + nx.min(Mi)
248+
249+
Gc = emd(slice1_distr, slice2_distr, Mi)
250+
deltaG = Gc - G
251+
costG = nx.sum(M * G) + 0.1 * ot.gromov.gwloss(constC, hC1, hC2, G)
252+
alpha, fc, cost_G = solve_gromov_linesearch(
253+
G, deltaG, costG, slice1_dist, slice2_dist, M=0.0, reg=1.0, nx=nx
254+
)
255+
assert alpha == 1.0
256+
assert fc == 1
257+
assert round(cost_G,6) == -11.419226

0 commit comments

Comments
 (0)