Skip to content

Commit f3d94b0

Browse files
author
FrancescaDr
committed
obs_names and var_names add to fields
1 parent f2a58ed commit f3d94b0

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

src/geome/ann2data/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _convert_to_tensor(self, obj):
8787
if obj.dtype.name == "category":
8888
return torch.from_numpy(pd.get_dummies(obj).to_numpy()).to(torch.float)
8989
if not np.issubdtype(obj.dtype, np.number):
90-
return torch.from_numpy(obj.astype(np.float)).to(torch.float)
90+
return torch.from_numpy(obj.astype(np.float64)).to(torch.float)
9191
if isinstance(obj, np.ndarray):
9292
return torch.from_numpy(obj).to(torch.float)
9393
else:

src/geome/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ def get_from_loc(adata: AnnData, location: str) -> Any:
2121
"""
2222
if location == "X":
2323
return adata.X
24+
elif location == "obs_names":
25+
return adata.obs_names.to_numpy()
26+
elif location == "var_names":
27+
return adata.var_names.to_numpy()
28+
2429
assert len(location.split("/")) == 2, f"Location must have only one delimiter {location}"
2530
axis, key = location.split("/")
2631

tests/ann2data/test_ann2data_by_category.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,22 @@ def test_sample_case_ann2data_basic():
1010
# make sure that there are two clusters of spatial coordinates
1111
# so that the resulting splits number of edges will be the same
1212
# as the sum of the number of edges in each cluster
13-
func_args = {"radius": 4.0, "coord_type": "generic"}
13+
func_args = {"radius": 4.0, "coord_type": "generic", "library_key": "image_id"}
1414
coordinates[:25, 0] += 100
1515
adata_gt = ad.AnnData(
1616
np.random.rand(50, 2),
17-
obs={"cell_type": ["a"] * 25 + ["b"] * 25, "image_id": list("cd" * 25)},
17+
obs={"cell_type": ["a"] * 20 + ["b"] * 20 + ["c"] * 5 + ["d"] * 5, "image_id": list("xy" * 20) + ["z"] * 10},
1818
obsm={"spatial_init": coordinates},
1919
)
2020
a2d = ann2data.Ann2DataByCategory(
21-
fields={"x": ["X"], "edge_index": ["uns/edge_index"], "edge_weight": ["uns/edge_weight"]},
21+
fields={
22+
"x": ["X"],
23+
"obs_names": ["obs_names"],
24+
"var_names": ["var_names"],
25+
"edge_index": ["uns/edge_index"],
26+
"edge_weight": ["uns/edge_weight"],
27+
"y": ["obs/cell_type"],
28+
},
2229
category="cell_type",
2330
preprocess=transforms.Categorize(keys=["cell_type", "image_id"]),
2431
transform=transforms.AddEdgeIndex(
@@ -30,7 +37,7 @@ def test_sample_case_ann2data_basic():
3037
),
3138
)
3239
datas = list(a2d(adata_gt.copy()))
33-
assert len(datas) == 2
40+
assert len(datas) == 3
3441
big_adata_tf = transforms.Compose(
3542
[
3643
transforms.Categorize(keys=["cell_type", "image_id"]),

0 commit comments

Comments
 (0)