Skip to content

Commit

Permalink
Added more test cases and test files (#4)
Browse files Browse the repository at this point in the history
introducing test case for center_ot, center_nmf, and respective test files
  • Loading branch information
anushka255 authored Aug 8, 2024
1 parent 4f97346 commit ed59977
Show file tree
Hide file tree
Showing 15 changed files with 2,661 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/paste/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .PASTE import pairwise_align, center_align
from .helper import match_spots_using_spatial_heuristic, filter_for_common_genes, apply_trsf
from .PASTE import pairwise_align, center_align, center_ot, center_NMF
from .helper import match_spots_using_spatial_heuristic, filter_for_common_genes, apply_trsf, intersect
from .visualization import plot_slice, stack_slices_pairwise, stack_slices_center
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import scanpy as sc
import pytest
from src.paste import intersect

test_dir = Path(__file__).parent
input_dir = test_dir / "data/input"
Expand All @@ -22,3 +23,15 @@ def slices():
slices.append(_slice)

return slices


@pytest.fixture(scope="session")
def intersecting_slices(slices):
common_genes = slices[0].var.index
for slice in slices[1:]:
common_genes = intersect(common_genes, slice.var.index)

for i in range(len(slices)):
slices[i] = slices[i][:, common_genes]

return slices
15 changes: 15 additions & 0 deletions tests/data/input/H_intermediate.csv

Large diffs are not rendered by default.

254 changes: 254 additions & 0 deletions tests/data/input/W_intermediate.csv

Large diffs are not rendered by default.

254 changes: 254 additions & 0 deletions tests/data/input/center_ot1_pairwise.csv

Large diffs are not rendered by default.

254 changes: 254 additions & 0 deletions tests/data/input/center_ot2_pairwise.csv

Large diffs are not rendered by default.

254 changes: 254 additions & 0 deletions tests/data/input/center_ot3_pairwise.csv

Large diffs are not rendered by default.

254 changes: 254 additions & 0 deletions tests/data/input/center_ot4_pairwise.csv

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions tests/data/output/H_center_NMF.csv

Large diffs are not rendered by default.

255 changes: 255 additions & 0 deletions tests/data/output/W_center_NMF.csv

Large diffs are not rendered by default.

255 changes: 255 additions & 0 deletions tests/data/output/center_ot1_pairwise.csv

Large diffs are not rendered by default.

255 changes: 255 additions & 0 deletions tests/data/output/center_ot2_pairwise.csv

Large diffs are not rendered by default.

255 changes: 255 additions & 0 deletions tests/data/output/center_ot3_pairwise.csv

Large diffs are not rendered by default.

255 changes: 255 additions & 0 deletions tests/data/output/center_ot4_pairwise.csv

Large diffs are not rendered by default.

73 changes: 70 additions & 3 deletions tests/test_paste.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import hashlib
from pathlib import Path

import numpy as np
import ot.backend
import pandas as pd
import tempfile

from src.paste import pairwise_align, center_align
from src.paste import pairwise_align, center_align, center_ot, intersect, center_NMF

test_dir = Path(__file__).parent
input_dir = test_dir / "data/input"
Expand Down Expand Up @@ -70,5 +73,69 @@ def test_center_alignment(slices):
for i, pi in enumerate(pairwise_info):
pd.DataFrame(
pi, index=center_slice.obs.index, columns=slices[i].obs.index
).to_csv(temp_dir / f"center_slice{i+1}_pairwise.csv")
assert_checksum_equals(temp_dir, f"center_slice{i+1}_pairwise.csv")
).to_csv(temp_dir / f"center_slice{i + 1}_pairwise.csv")
assert_checksum_equals(temp_dir, f"center_slice{i + 1}_pairwise.csv")


def test_center_ot(slices):
temp_dir = Path(tempfile.mkdtemp())

common_genes = slices[0].var.index
for slice in slices[1:]:
common_genes = intersect(common_genes, slice.var.index)

intersecting_slice = slices[0][:, common_genes]
pairwise_info, r = center_ot(
W=np.genfromtxt(input_dir / "W_intermediate.csv", delimiter=","),
H=np.genfromtxt(input_dir / "H_intermediate.csv", delimiter=","),
slices=slices,
center_coordinates=intersecting_slice.obsm["spatial"],
common_genes=common_genes,
use_gpu=False,
alpha=0.1,
backend=ot.backend.NumpyBackend(),
dissimilarity="kl",
norm=False,
G_inits=[None for _ in range(len(slices))],
)

expected_r = [
-25.08051355206619,
-26.139415232102213,
-25.728504876394076,
-25.740615316378296,
]

assert np.all(np.isclose(expected_r, r, rtol=1e-05, atol=1e-08, equal_nan=True))

for i, pi in enumerate(pairwise_info):
pd.DataFrame(
pi, index=intersecting_slice.obs.index, columns=slices[i].obs.index
).to_csv(temp_dir / f"center_ot{i + 1}_pairwise.csv")
assert_checksum_equals(temp_dir, f"center_ot{i + 1}_pairwise.csv")


def test_center_NMF(intersecting_slices):
temp_dir = Path(tempfile.mkdtemp())
n_slices = len(intersecting_slices)

pairwise_info = [
np.genfromtxt(input_dir / f"center_ot{i+1}_pairwise.csv", delimiter=",")
for i in range(n_slices)
]

_W, _H = center_NMF(
W=np.genfromtxt(input_dir / "W_intermediate.csv", delimiter=","),
H=np.genfromtxt(input_dir / "H_intermediate.csv", delimiter=","),
slices=intersecting_slices,
pis=pairwise_info,
lmbda=n_slices * [1.0 / n_slices],
n_components=15,
random_seed=0,
)

pd.DataFrame(_W).to_csv(temp_dir / "W_center_NMF.csv")
pd.DataFrame(_H).to_csv(temp_dir / "H_center_NMF.csv")
# TODO: The following computations seem to be architecture dependent (need to look into as for how)
# assert_checksum_equals(temp_dir, "W_center_NMF.csv")
# assert_checksum_equals(temp_dir, "H_center_NMF.csv")

0 comments on commit ed59977

Please sign in to comment.