@@ -10,15 +10,22 @@ def test_sample_case_ann2data_basic():
10
10
# make sure that there are two clusters of spatial coordinates
11
11
# so that the resulting splits number of edges will be the same
12
12
# as the sum of the number of edges in each cluster
13
- func_args = {"radius" : 4.0 , "coord_type" : "generic" }
13
+ func_args = {"radius" : 4.0 , "coord_type" : "generic" , "library_key" : "image_id" }
14
14
coordinates [:25 , 0 ] += 100
15
15
adata_gt = ad .AnnData (
16
16
np .random .rand (50 , 2 ),
17
- obs = {"cell_type" : ["a" ] * 25 + ["b" ] * 25 , "image_id" : list ("cd " * 25 ) },
17
+ obs = {"cell_type" : ["a" ] * 20 + ["b" ] * 20 + [ "c" ] * 5 + [ "d" ] * 5 , "image_id" : list ("xy " * 20 ) + [ "z" ] * 10 },
18
18
obsm = {"spatial_init" : coordinates },
19
19
)
20
20
a2d = ann2data .Ann2DataByCategory (
21
- fields = {"x" : ["X" ], "edge_index" : ["uns/edge_index" ], "edge_weight" : ["uns/edge_weight" ]},
21
+ fields = {
22
+ "x" : ["X" ],
23
+ "obs_names" : ["obs_names" ],
24
+ "var_names" : ["var_names" ],
25
+ "edge_index" : ["uns/edge_index" ],
26
+ "edge_weight" : ["uns/edge_weight" ],
27
+ "y" : ["obs/cell_type" ],
28
+ },
22
29
category = "cell_type" ,
23
30
preprocess = transforms .Categorize (keys = ["cell_type" , "image_id" ]),
24
31
transform = transforms .AddEdgeIndex (
@@ -30,7 +37,7 @@ def test_sample_case_ann2data_basic():
30
37
),
31
38
)
32
39
datas = list (a2d (adata_gt .copy ()))
33
- assert len (datas ) == 2
40
+ assert len (datas ) == 3
34
41
big_adata_tf = transforms .Compose (
35
42
[
36
43
transforms .Categorize (keys = ["cell_type" , "image_id" ]),
0 commit comments