import pickle

import numpy as np

from gcl_classifier.labels import baden_cluster_id_to_group_id, baden_group_id_to_supergroup, BADEN_CLUSTER_INFO


def classify_cells(preproc_chirps, preproc_bars, bar_ds_pvalues, roi_size_um2s,
                   chirp_features, bar_features, classifier):
    """
    Extracts features from the preprocessed responses and returns the class probabilities
    predicted by the classifier, one row per cell.
    """
    features, feature_names = extract_features(
        preproc_chirps, preproc_bars, bar_ds_pvalues, roi_size_um2s, chirp_features, bar_features)
    probs = classifier.predict_proba(features)
    return probs
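

# A minimal usage sketch: classify cells with a classifier dict loaded from disk, as
# produced by save_classifier_and_data below. The function name and the file handling
# here are illustrative assumptions, not part of the package API.
def example_classify_from_file(classifier_file, preproc_chirps, preproc_bars,
                               bar_ds_pvalues, roi_size_um2s):
    with open(classifier_file, 'rb') as f:
        clf_dict = check_classifier_dict(pickle.load(f))
    # One row of class probabilities per cell.
    return classify_cells(preproc_chirps, preproc_bars, bar_ds_pvalues, roi_size_um2s,
                          clf_dict['chirp_feats'], clf_dict['bar_feats'],
                          clf_dict['classifier'])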


def baden16_cluster_probs_to_info(probs):
    """
    Maps a vector of 75 Baden et al. (2016) cluster probabilities to the most likely
    cluster, its group and supergroup, and the aggregated probabilities at each level.
    """
    if len(probs) != 75:
        raise ValueError(f"Expected 75 probabilities corresponding to 75 Baden clusters, got {len(probs)}.")

    cluster_id = np.argmax(probs) + 1  # Cluster IDs are 1-indexed
    group_id = baden_cluster_id_to_group_id(cluster_id)
    supergroup = baden_group_id_to_supergroup(group_id)
    prob_cluster = probs[cluster_id - 1]

    group_ids = BADEN_CLUSTER_INFO[:, 2].astype(int)
    supergroups = BADEN_CLUSTER_INFO[:, 3].astype(str)

    prob_group = np.sum(probs[group_ids == group_id])
    prob_supergroup = np.sum(probs[supergroups == supergroup])
    prob_rgc = np.sum(probs[supergroups != 'dAC'])
    prob_class = (1. - prob_rgc) if supergroup == 'dAC' else prob_rgc

    return cluster_id, group_id, supergroup, prob_cluster, prob_group, prob_supergroup, prob_class
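

# A minimal usage sketch: summarize every cell of a probability matrix returned by
# classify_cells. The function name and the dictionary keys are illustrative
# assumptions, not part of the package API.
def example_summarize_cells(probs_per_cell):
    summaries = []
    for cell_probs in probs_per_cell:
        (cluster_id, group_id, supergroup, prob_cluster,
         prob_group, prob_supergroup, prob_class) = baden16_cluster_probs_to_info(cell_probs)
        summaries.append({
            'cluster_id': int(cluster_id), 'group_id': group_id, 'supergroup': supergroup,
            'prob_cluster': float(prob_cluster), 'prob_group': float(prob_group),
            'prob_supergroup': float(prob_supergroup), 'prob_class': float(prob_class),
        })
    return summaries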


def extract_features(
        preproc_chirps,
        preproc_bars,
        bar_ds_pvalues,
        roi_size_um2s,
        chirp_features,
        bar_features,
) -> tuple[np.ndarray, list[str]]:
    """
    Transforms the preprocessed chirps and bars using the provided chirp/bar features,
    concatenates the results with the bar DS p-values and ROI sizes (µm²), and returns
    the feature matrix together with the feature names.
    The result can be used as input to the classifier.
    """
    features = np.concatenate([
        np.dot(preproc_chirps, chirp_features),
        np.dot(preproc_bars, bar_features),
        bar_ds_pvalues[:, np.newaxis],
        roi_size_um2s[:, np.newaxis]
    ], axis=-1)

    feature_names = [f'chirp_{i}' for i in range(chirp_features.shape[1])] + \
                    [f'bar_{i}' for i in range(bar_features.shape[1])] + \
                    ['bar_ds_pvalue', 'roi_size_um2']

    return features, feature_names
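

# A minimal shape check, assuming N cells, preprocessed chirps of shape (N, T_chirp),
# chirp features of shape (T_chirp, K_chirp), and analogous shapes for the bars; the
# resulting feature matrix then has K_chirp + K_bar + 2 columns. All sizes and the
# function name are arbitrary illustrative assumptions.
def example_feature_shapes(n_cells=4, t_chirp=100, k_chirp=20, t_bar=32, k_bar=8):
    rng = np.random.default_rng(0)
    features, names = extract_features(
        rng.normal(size=(n_cells, t_chirp)), rng.normal(size=(n_cells, t_bar)),
        rng.uniform(size=n_cells), rng.uniform(size=n_cells),
        rng.normal(size=(t_chirp, k_chirp)), rng.normal(size=(t_bar, k_bar)))
    assert features.shape == (n_cells, k_chirp + k_bar + 2)
    assert len(names) == features.shape[1]
    return features, names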


def check_classifier_dict(clf_dict: dict) -> dict:
    """
    Validates the structure of a classifier dictionary and returns it unchanged.
    """
    assert isinstance(clf_dict, dict), "Classifier file must contain a dictionary with classifier data."

    # Check keys
    for key in ('classifier', 'chirp_feats', 'bar_feats', 'feature_names', 'train_x', 'train_y', 'y_names'):
        assert key in clf_dict, f"Classifier dictionary must contain a '{key}' key."

    # Check values
    assert isinstance(clf_dict['train_x'], np.ndarray), "The 'train_x' key must contain a numpy array."
    assert isinstance(clf_dict['train_y'], np.ndarray), "The 'train_y' key must contain a numpy array."
    assert clf_dict['train_x'].shape[0] == clf_dict['train_y'].size, \
        "The number of samples in 'train_x' and 'train_y' must match."

    for val in np.unique(clf_dict['train_y']):
        assert val in clf_dict['y_names'], f"Value {val} in 'train_y' not found in 'y_names'."

    # Check if classifier is a valid scikit-learn classifier
    from sklearn.base import is_classifier
    assert is_classifier(clf_dict['classifier']), "The 'classifier' key must contain a valid scikit-learn classifier."

    return clf_dict


def save_classifier_and_data(classifier, chirp_feats, bar_feats, feature_names, train_x, train_y, y_names,
                             classifier_file, **kwargs) -> None:
    """
    Saves the classifier and its metadata to a file.
    """
    clf_dict = {
        'classifier': classifier,
        'chirp_feats': chirp_feats,
        'bar_feats': bar_feats,
        'feature_names': feature_names,
        'train_x': train_x,
        'train_y': train_y,
        'y_names': y_names,
        **kwargs
    }

    check_classifier_dict(clf_dict)

    with open(classifier_file, 'wb') as f:
        pickle.dump(clf_dict, f)
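

# A minimal round-trip sketch: fit a scikit-learn classifier on pre-extracted features,
# persist it with save_classifier_and_data, and reload it via check_classifier_dict.
# The RandomForest choice, the default file name, and the y_names mapping are
# illustrative assumptions, not the package's actual training setup.
def example_save_and_reload(train_x, train_y, chirp_feats, bar_feats, feature_names,
                            classifier_file='example_classifier.pkl'):
    from sklearn.ensemble import RandomForestClassifier
    clf = RandomForestClassifier(n_estimators=100, random_state=0).fit(train_x, train_y)
    y_names = {int(y): f'class_{int(y)}' for y in np.unique(train_y)}
    save_classifier_and_data(clf, chirp_feats, bar_feats, feature_names,
                             train_x, train_y, y_names, classifier_file)
    with open(classifier_file, 'rb') as f:
        return check_classifier_dict(pickle.load(f))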