Skip to content

Commit

Permalink
Merge pull request #47 from theislab/refactor
Browse files Browse the repository at this point in the history
Adding tests and interface improvements
  • Loading branch information
selmanozleyen authored Apr 30, 2024
2 parents 74b5795 + 574fd97 commit 991a313
Show file tree
Hide file tree
Showing 26 changed files with 613 additions and 454 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ __pycache__/
/.idea/
/.vscode/

**/lightning_logs/**
**/lightning_logs/**
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ pip install geome
1. Install the latest development version:

```bash
mamba create -n geome python=3.11
mamba activate geome
mamba install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
pip install git+https://github.com/theislab/geome.git@main
```

Expand Down
5 changes: 4 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@


# -- Project information -----------------------------------------------------

# NOTE: If you installed your project in editable mode, this might be stale.
# If this is the case, reinstall it to refresh the metadata
# info = metadata("geome")
# project_name = info["Name"]
project_name = "geome"
Expand Down Expand Up @@ -114,12 +115,14 @@
#
html_theme = "furo"
html_static_path = ["_static"]
html_css_files = ["css/custom.css"]
html_title = project_name

html_theme_options = {
"repository_url": repository_url,
"use_repository_button": True,
"path_to_docs": "docs/",
"navigation_with_keys": False,
}

pygments_style = "default"
Expand Down
34 changes: 0 additions & 34 deletions docs/conf.py.rej

This file was deleted.

8 changes: 0 additions & 8 deletions docs/index.md.rej

This file was deleted.

15 changes: 9 additions & 6 deletions docs/notebooks/1_iterables_and_iterators.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,9 @@
"metadata": {},
"outputs": [],
"source": [
"from geome import iterables, transforms, ann2data\n",
"from geome import iterables, ann2data\n",
"import squidpy as sq\n",
"import numpy as np\n",
"from anndata import AnnData"
"import numpy as np"
]
},
{
Expand Down Expand Up @@ -162,7 +161,7 @@
],
"source": [
"split_adatas = list(to_iterable(adata)) # split by cluster\n",
"assert len(split_adatas) == len(adata.obs[\"Cluster\"].cat.categories) # ensure all clusters have their own adata\n",
"assert len(split_adatas) == len(adata.obs[\"Cluster\"].cat.categories) # ensure all clusters have their own adata\n",
"split_adatas[:3] # show first 3"
]
},
Expand All @@ -183,7 +182,9 @@
"metadata": {},
"outputs": [],
"source": [
"assert all(len(ad.obs[\"Cluster\"].cat.categories) == len(adata.obs[\"Cluster\"].cat.categories) for ad in split_adatas) # ensure all splits have the same category"
"assert all(\n",
" len(ad.obs[\"Cluster\"].cat.categories) == len(adata.obs[\"Cluster\"].cat.categories) for ad in split_adatas\n",
") # ensure all splits have the same category"
]
},
{
Expand Down Expand Up @@ -297,7 +298,9 @@
"metadata": {},
"outputs": [],
"source": [
"assert all(np.allclose(r1.x, r2.x) for r1, r2 in zip(result1, result2)) and all(np.allclose(r1.x, r3.x) for r1, r3 in zip(result1, result3))"
"assert all(np.allclose(r1.x, r2.x) for r1, r2 in zip(result1, result2)) and all(\n",
" np.allclose(r1.x, r3.x) for r1, r3 in zip(result1, result3)\n",
")"
]
}
],
Expand Down
83 changes: 35 additions & 48 deletions docs/notebooks/2_transforms_and_preprocessing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,22 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from geome import iterables, transforms, ann2data\n",
"from geome import transforms\n",
"import squidpy as sq\n",
"import numpy as np\n",
"from anndata import AnnData"
]
},
Expand All @@ -76,19 +66,12 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Load some simple adata sample\n",
"def create_adata():\n",
" adata = sq.datasets.mibitof()\n",
" simple_adata = AnnData(adata.X)\n",
" simple_adata.obs[\"Cluster\"] = adata.obs[\"Cluster\"]\n",
" simple_adata.obsp[\"connectivities\"] = adata.obsp[\"connectivities\"]\n",
" return simple_adata\n",
"\n",
"adata = create_adata()"
"adata = sq.datasets.mibitof()"
]
},
{
Expand All @@ -102,18 +85,21 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AnnData object with n_obs × n_vars = 3309 × 36\n",
" obs: 'Cluster'\n",
" obsp: 'connectivities'"
" obs: 'row_num', 'point', 'cell_id', 'X1', 'center_rowcoord', 'center_colcoord', 'cell_size', 'category', 'donor', 'Cluster', 'batch', 'library_id'\n",
" var: 'mean-0', 'std-0', 'mean-1', 'std-1', 'mean-2', 'std-2'\n",
" uns: 'Cluster_colors', 'batch_colors', 'neighbors', 'spatial', 'umap'\n",
" obsm: 'X_scanorama', 'X_umap', 'spatial'\n",
" obsp: 'connectivities', 'distances'"
]
},
"execution_count": 12,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -124,13 +110,12 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"adds_edge_index = transforms.AddEdgeIndex(\n",
" adj_matrix_loc=\"obsp/connectivities\", edge_index_key=\"edge_index\", overwrite=True\n",
")"
"sq_neighbors_args = {\"radius\": 4.0, \"coord_type\": \"generic\"}\n",
"adds_edge_index = transforms.AddEdgeIndex(edge_index_key=\"edge_index\", edge_weight_key=\"edge_weight\", func_args=sq_neighbors_args, spatial_key=\"spatial\", key_added=\"added\")"
]
},
{
Expand All @@ -142,19 +127,21 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AnnData object with n_obs × n_vars = 3309 × 36\n",
" obs: 'Cluster'\n",
" uns: 'edge_index'\n",
" obsp: 'connectivities'"
" obs: 'row_num', 'point', 'cell_id', 'X1', 'center_rowcoord', 'center_colcoord', 'cell_size', 'category', 'donor', 'Cluster', 'batch', 'library_id'\n",
" var: 'mean-0', 'std-0', 'mean-1', 'std-1', 'mean-2', 'std-2'\n",
" uns: 'Cluster_colors', 'batch_colors', 'neighbors', 'spatial', 'umap', 'added_neighbors', 'edge_index', 'edge_weight'\n",
" obsm: 'X_scanorama', 'X_umap', 'spatial'\n",
" obsp: 'connectivities', 'distances', 'added_connectivities', 'added_distances'"
]
},
"execution_count": 14,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -165,42 +152,42 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"multiple_transforms = transforms.Compose( # you can get creative with this\n",
" [\n",
" transforms.AddEdgeIndex(adj_matrix_loc=\"obsp/connectivities\", edge_index_key=\"edge_index\", overwrite=True),\n",
" transforms.AddEdgeWeight(\n",
" weight_matrix_loc=\"obsp/connectivities\",\n",
" edge_index_key=\"edge_index\",\n",
" edge_weight_key=\"edge_weight\",\n",
" overwrite=True,\n",
" ),\n",
" transforms.AddAdjMatrix(func_args=sq_neighbors_args, key_added=\"added2\", spatial_key=\"spatial\"),\n",
" transforms.AddEdgeIndexFromAdj(adj_matrix_loc=\"obsp/added2_connectivities\", edge_index_key=\"edge_index2\", edge_weight_key=\"edge_weight2\"),\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([71500])"
"AnnData object with n_obs × n_vars = 3309 × 36\n",
" obs: 'row_num', 'point', 'cell_id', 'X1', 'center_rowcoord', 'center_colcoord', 'cell_size', 'category', 'donor', 'Cluster', 'batch', 'library_id'\n",
" var: 'mean-0', 'std-0', 'mean-1', 'std-1', 'mean-2', 'std-2'\n",
" uns: 'Cluster_colors', 'batch_colors', 'neighbors', 'spatial', 'umap', 'added_neighbors', 'edge_index', 'edge_weight', 'added2_neighbors', 'edge_index2', 'edge_weight2'\n",
" obsm: 'X_scanorama', 'X_umap', 'spatial'\n",
" obsp: 'connectivities', 'distances', 'added_connectivities', 'added_distances', 'added2_connectivities', 'added2_distances'"
]
},
"execution_count": 16,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res = multiple_transforms(adata)\n",
"res.uns[\"edge_weight\"].shape"
"res"
]
}
],
Expand Down
Loading

0 comments on commit 991a313

Please sign in to comment.