Skip to content

Commit

Permalink
Allow user to dynamically set batch size of pixel SOM assignments (#1069)
Browse files Browse the repository at this point in the history

* Ensure default batch size is efficient for big machines and large cohorts, and also that users can change it if necessary

* Update parameter ordering

* Clarify param passed into abstract ClusterHelpers class

* Add num_parallel_cells argument to cell SOM cluster label assignment

* Set up testing with GitHub Actions on Python 3.11
  • Loading branch information
alex-l-kong authored Oct 30, 2023
1 parent ad859b2 commit bfc53b7
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: "3.x"
python-version: "3.11"
cache-dependency-path: "**/pyproject.toml"
cache: "pip"
- name: Check the Example Dataset Cache
Expand Down
7 changes: 5 additions & 2 deletions src/ark/phenotyping/cell_som_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def train_cell_som(fovs, base_dir, cell_table_path, cell_som_cluster_cols,
return cell_pysom


def cluster_cells(base_dir, cell_pysom, cell_som_cluster_cols, overwrite=False):
def cluster_cells(base_dir, cell_pysom, cell_som_cluster_cols, num_parallel_cells=1000000,
overwrite=False):
"""Uses trained SOM weights to assign cluster labels on full cell data.
Saves data with cluster labels to `cell_cluster_name`.
Expand All @@ -87,6 +88,8 @@ def cluster_cells(base_dir, cell_pysom, cell_som_cluster_cols, overwrite=False):
The SOM cluster object containing the cell SOM weights
cell_som_cluster_cols (list):
The list of columns used for SOM training
num_parallel_cells (int):
How many cells to label in parallel at once
overwrite (bool):
If set, overwrites the SOM cluster assignments if they exist
Expand Down Expand Up @@ -131,7 +134,7 @@ def cluster_cells(base_dir, cell_pysom, cell_som_cluster_cols, overwrite=False):

# run the trained SOM on the dataset, assigning clusters
print("Mapping cell data to SOM cluster labels")
cell_data_som_labels = cell_pysom.assign_som_clusters()
cell_data_som_labels = cell_pysom.assign_som_clusters(num_parallel_cells)

return cell_data_som_labels

Expand Down
36 changes: 27 additions & 9 deletions src/ark/phenotyping/cluster_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,24 @@ def train_som(self, data: pd.DataFrame):
# save the weights to weights_path
feather.write_dataframe(self.weights, self.weights_path, compression='uncompressed')

def generate_som_clusters(self, external_data: pd.DataFrame) -> np.ndarray:
def generate_som_clusters(self, external_data: pd.DataFrame,
num_parallel_obs: int = 1000000) -> np.ndarray:
"""Uses the weights to generate SOM clusters for a dataset
Args:
external_data (pandas.DataFrame):
The dataset to generate SOM clusters for
num_parallel_obs (int):
Partition size of `external_data` for assigning SOM labels
Returns:
numpy.ndarray:
The SOM clusters generated for each observation (row) in `external_data`
"""
# ensure num_parallel_obs passed is valid
if num_parallel_obs <= 0:
raise ValueError("num_parallel_obs specified needs to be greater than 0")

# subset on just the weights columns prior to SOM cluster mapping
weights_cols = self.weights.columns.values

Expand All @@ -138,14 +145,14 @@ def generate_som_clusters(self, external_data: pd.DataFrame) -> np.ndarray:
# define the batches of cluster labels assigned
cluster_labels = []

# work in batches of 100 to account to support large dataframe sizes
# work in batches to support large dataframe sizes
# TODO: possible dynamic computation in order?
for i in np.arange(0, external_data.shape[0], 100):
for i in np.arange(0, external_data.shape[0], num_parallel_obs):
# NOTE: this also orders the columns of external_data_sub the same as self.weights
cluster_labels.append(map_data_to_nodes(
self.weights.values.astype(np.float64),
external_data.loc[
i:min(i + 99, external_data.shape[0]), weights_cols
i:min(i + num_parallel_obs - 1, external_data.shape[0]), weights_cols
].values.astype(np.float64)
)[0])

Expand Down Expand Up @@ -257,7 +264,9 @@ def train_som(self, overwrite=False):

super().train_som(self.train_data[self.columns])

def assign_som_clusters(self, external_data: pd.DataFrame, normalize_data: bool = True) -> pd.DataFrame:
def assign_som_clusters(self, external_data: pd.DataFrame,
normalize_data: bool = True,
num_parallel_pixels: int = 1000000) -> pd.DataFrame:
"""Assigns SOM clusters using `weights` to a dataset
Args:
Expand All @@ -266,14 +275,19 @@ def assign_som_clusters(self, external_data: pd.DataFrame, normalize_data: bool
normalize_data (bool):
Whether or not to normalize `external_data`.
Flag needed to prevent re-normalization of normalized dataset.
num_parallel_pixels (int):
Partition size of `external_data` for assigning SOM labels
Returns:
pandas.DataFrame:
The dataset with the SOM clusters assigned.
"""
# normalize external_data prior to assignment, if normalize_data set
external_data_norm = self.normalize_data(external_data) if normalize_data else external_data.copy()
som_labels = super().generate_som_clusters(external_data_norm)
external_data_norm = self.normalize_data(external_data) if normalize_data \
else external_data.copy()
som_labels = super().generate_som_clusters(
external_data_norm, num_parallel_obs=num_parallel_pixels
)

# assign SOM clusters to external_data
external_data_norm['pixel_som_cluster'] = som_labels
Expand Down Expand Up @@ -372,19 +386,23 @@ def train_som(self, overwrite=False):

super().train_som(self.cell_data[self.columns])

def assign_som_clusters(self) -> pd.DataFrame:
def assign_som_clusters(self, num_parallel_cells=1000000) -> pd.DataFrame:
"""Assigns SOM clusters using `weights` to `cell_data`
Args:
external_data (pandas.DataFrame):
The dataset to assign SOM clusters to
num_parallel_cells (int):
Partition size of `self.cell_data` for assigning SOM labels
Returns:
pandas.DataFrame:
`cell_data` with the SOM clusters assigned.
"""
# cell_data is already normalized, don't repeat
som_labels = super().generate_som_clusters(self.cell_data[self.columns])
som_labels = super().generate_som_clusters(
self.cell_data[self.columns], num_parallel_obs=num_parallel_cells
)

# assign SOM clusters to cell_data
self.cell_data['cell_som_cluster'] = som_labels
Expand Down
15 changes: 11 additions & 4 deletions src/ark/phenotyping/pixel_som_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def train_pixel_som(fovs, channels, base_dir,
return pixel_pysom


def run_pixel_som_assignment(pixel_data_path, pixel_pysom_obj, overwrite, fov):
def run_pixel_som_assignment(pixel_data_path, pixel_pysom_obj, overwrite, num_parallel_pixels, fov):
"""Helper function to assign pixel SOM cluster labels
Args:
Expand All @@ -100,6 +100,8 @@ def run_pixel_som_assignment(pixel_data_path, pixel_pysom_obj, overwrite, fov):
The pixel SOM cluster object
overwrite (bool):
Whether to overwrite the pixel SOM clusters or not
num_parallel_pixels (int):
How many pixels to label in parallel at once for each FOV
fov (str):
The name of the FOV to process
Expand All @@ -123,7 +125,9 @@ def run_pixel_som_assignment(pixel_data_path, pixel_pysom_obj, overwrite, fov):
fov_data = fov_data.drop(columns="pixel_som_cluster")

# assign the SOM labels to fov_data, overwrite flag indicates if data needs normalization
fov_data = pixel_pysom_obj.assign_som_clusters(fov_data, normalize_data=not overwrite)
fov_data = pixel_pysom_obj.assign_som_clusters(
fov_data, normalize_data=not overwrite, num_parallel_pixels=num_parallel_pixels
)

# resave the data with the SOM cluster labels assigned
temp_path = os.path.join(pixel_data_path + '_temp', fov + '.feather')
Expand All @@ -133,7 +137,8 @@ def run_pixel_som_assignment(pixel_data_path, pixel_pysom_obj, overwrite, fov):


def cluster_pixels(fovs, channels, base_dir, pixel_pysom, data_dir='pixel_mat_data',
multiprocess=False, batch_size=5, overwrite=False):
multiprocess=False, batch_size=5, num_parallel_pixels=1000000,
overwrite=False):
"""Uses trained SOM weights to assign cluster labels on full pixel data.
Saves data with cluster labels to `data_dir`.
Expand All @@ -153,6 +158,8 @@ def cluster_pixels(fovs, channels, base_dir, pixel_pysom, data_dir='pixel_mat_da
Whether to use multiprocessing or not
batch_size (int):
The number of FOVs to process in parallel, ignored if `multiprocess` is `False`
num_parallel_pixels (int):
How many pixels to label in parallel at once for each FOV
overwrite (bool):
If set, force overwrite the SOM labels in all the FOVs
"""
Expand Down Expand Up @@ -242,7 +249,7 @@ def cluster_pixels(fovs, channels, base_dir, pixel_pysom, data_dir='pixel_mat_da

# define the partial function to iterate over
fov_data_func = partial(
run_pixel_som_assignment, data_path, pixel_pysom, overwrite
run_pixel_som_assignment, data_path, pixel_pysom, overwrite, num_parallel_pixels
)

# use the som weights to assign SOM cluster values to data in data_dir
Expand Down
32 changes: 27 additions & 5 deletions tests/phenotyping/cluster_helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,8 @@ def test_train_som_new_cols(self):
self.pixel_pysom_nonweights.train_data = old_train_data
self.pixel_pysom_nonweights.columns = old_columns

def test_assign_som_clusters(self):
@parametrize("num_parallel_pixels", [10, 10000])
def test_assign_som_clusters(self, num_parallel_pixels):
# generate sample external data
# NOTE: test on shuffled data to ensure column matching
col_shuffle = deepcopy(self.pixel_pysom_nonweights.columns)
Expand All @@ -380,7 +381,9 @@ def test_assign_som_clusters(self):
)

# assign SOM labels to sample_external_data
som_label_data = self.pixel_pysom_nonweights.assign_som_clusters(sample_external_data)
som_label_data = self.pixel_pysom_nonweights.assign_som_clusters(
sample_external_data, num_parallel_pixels=num_parallel_pixels
)

# assert the som labels were assigned and they are all in the range 1 to 200
assert 'pixel_som_cluster' in som_label_data.columns.values
Expand All @@ -389,7 +392,7 @@ def test_assign_som_clusters(self):

# test normalize_data flag, shouldn't assign different SOM labels on same dataset
som_label_data_no_norm = self.pixel_pysom_nonweights.assign_som_clusters(
som_label_data, normalize_data=False
som_label_data, num_parallel_pixels=num_parallel_pixels, normalize_data=False
)

# test that no additional normalization added to som_label_data_no_norm
Expand All @@ -400,6 +403,22 @@ def test_assign_som_clusters(self):
new_som_clusters = som_label_data_no_norm['pixel_som_cluster'].values
assert np.all(som_clusters == new_som_clusters)

def test_assign_som_clusters_bad(self):
# Checks that assign_som_clusters raises ValueError when num_parallel_pixels is 0.

# shuffle a copy of the SOM columns so the original column list is not reordered
col_shuffle = deepcopy(self.pixel_pysom_nonweights.columns)
random.shuffle(col_shuffle)
meta_cols = ['fov', 'row_index', 'column_index', 'label']
# random sample data: 1000 rows, shuffled SOM columns followed by the metadata columns
# NOTE(review): the width of 10 presumably means 6 SOM columns + 4 metadata columns -- confirm
sample_external_data = pd.DataFrame(
np.random.rand(1000, 10),
columns=col_shuffle + meta_cols
)

# num_parallel_pixels=0 is not a valid partition size, so the call must raise
# instead of assigning SOM labels to sample_external_data
with pytest.raises(ValueError):
som_label_data = self.pixel_pysom_nonweights.assign_som_clusters(
sample_external_data, num_parallel_pixels=0,
normalize_data=False
)


class TestCellSOMCluster:
@pytest.fixture(autouse=True, scope="function")
Expand Down Expand Up @@ -485,9 +504,12 @@ def test_train_som_new_cols(self):
self.cell_pysom_nonweights.cell_data = old_cell_data
self.cell_pysom_nonweights.columns = old_columns

def test_assign_som_clusters(self):
@parametrize("num_parallel_cells", [10, 10000])
def test_assign_som_clusters(self, num_parallel_cells):
# generate SOM cluster values for cell_data
som_label_data = self.cell_pysom_nonweights.assign_som_clusters()
som_label_data = self.cell_pysom_nonweights.assign_som_clusters(
num_parallel_cells=num_parallel_cells
)

# assert the som labels were assigned and they are all in the range 1 to 200
assert 'cell_som_cluster' in som_label_data.columns.values
Expand Down
4 changes: 2 additions & 2 deletions tests/phenotyping/pixel_som_clustering_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_run_pixel_som_assignment():
)

fov_status = pixel_som_clustering.run_pixel_som_assignment(
os.path.join(temp_dir, 'pixel_mat_data'), sample_pixel_cc, False, 'fov0'
os.path.join(temp_dir, 'pixel_mat_data'), sample_pixel_cc, False, 10, 'fov0'
)

# assert the fov returned is fov0 and the status is 0
Expand All @@ -82,7 +82,7 @@ def test_run_pixel_som_assignment():

# attempt to run remapping for fov1
fov_status = pixel_som_clustering.run_pixel_som_assignment(
os.path.join(temp_dir, 'pixel_mat_data'), sample_pixel_cc, False, 'fov1'
os.path.join(temp_dir, 'pixel_mat_data'), sample_pixel_cc, False, 10, 'fov1'
)

# assert the fov returned is fov1 and the status is 1
Expand Down

0 comments on commit bfc53b7

Please sign in to comment.