Skip to content

Commit

Permalink
check encoder-ML method compatibility upon YAML parsing (#106)
Browse files Browse the repository at this point in the history
* check encoder-ML method compatibility upon YAML parsing

* changed class to KNN in tests

* move encoder imports to the check function

Co-authored-by: pavlovicmilena <milenapavlovic775@gmail.com>
  • Loading branch information
LonnekeScheffer and pavlovicmilena authored Oct 13, 2021
1 parent 29c9f27 commit 426cfe9
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 12 deletions.
9 changes: 8 additions & 1 deletion docs/source/developer_docs/how_to_add_new_ML_method.rst
Original file line number Diff line number Diff line change
Expand Up @@ -346,4 +346,11 @@ To run this from the root directory of the project, save the specification to sp

.. code-block:: console
immune-ml specs.yaml output_dir/
immune-ml specs.yaml output_dir/
Compatible encoders
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Each ML method is only compatible with a limited set of encoders. immuneML automatically checks if the given encoder and ML method are
compatible when running the TrainMLModel instruction, and raises an error if they are not compatible.
To ensure immuneML recognizes the encoder-ML method compatibility, make sure that the encoder(s) of interest is added to the list
of encoder classes returned by the :code:`get_compatible_encoders()` method of the ML method.
8 changes: 8 additions & 0 deletions docs/source/developer_docs/how_to_add_new_encoding.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ To add a new encoder:
4. Implement the abstract methods :code:`encode()` and :code:`build_object()`.
5. Implement methods to import and export an encoder: :code:`get_additional_files()`, :code:`export_encoder()` and :code:`load_encoder()`, mostly relying on functionality already available in :py:obj:`~immuneML.encodings.DatasetEncoder.DatasetEncoder`.
6. Add class documentation including: what the encoder does, what the arguments are and an example on how to use it from YAML specification.
7. Add the new encoder class to the list of compatible encoders returned by the :code:`get_compatible_encoders()` method of the :ref:`MLMethod` of interest.

An example of the implementation of :code:`NewKmerFrequencyEncoder` for the :py:obj:`~immuneML.data_model.dataset.RepertoireDataset.RepertoireDataset` is shown.

Expand Down Expand Up @@ -353,3 +354,10 @@ This is the example of documentation for :py:obj:`~immuneML.encodings.filtered_s
p_value_threshold: 0.05
sequence_batch_size: 100000
repertoire_batch_size: 32
Compatible ML methods
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Each ML method is only compatible with a limited set of encoders. immuneML automatically checks if the given encoder and ML method are
compatible when running the TrainMLModel instruction, and raises an error if they are not compatible.
To ensure immuneML recognizes the encoder-ML method compatibility, make sure that the encoder is added to the list of encoder classes
returned by the :code:`get_compatible_encoders()` method of the ML method(s) of interest.
6 changes: 5 additions & 1 deletion immuneML/dsl/instruction_parsers/TrainMLModelParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,14 @@ def _parse_settings(self, instruction: dict, symbol_table: SymbolTable) -> list:
**symbol_table.get_config(setting["encoding"])["encoder_params"])\
.set_context({"dataset": symbol_table.get(instruction['dataset'])})

ml_method = symbol_table.get(setting["ml_method"])
ml_method.check_encoder_compatibility(encoder)

s = HPSetting(encoder=encoder,
encoder_name=setting["encoding"],
encoder_params=symbol_table.get_config(setting["encoding"])["encoder_params"],
ml_method=symbol_table.get(setting["ml_method"]), ml_method_name=setting["ml_method"],
ml_method=ml_method,
ml_method_name=setting["ml_method"],
ml_params=symbol_table.get_config(setting["ml_method"]),
preproc_sequence=preprocessing_sequence, preproc_sequence_name=preproc_name)
settings.append(s)
Expand Down
9 changes: 4 additions & 5 deletions immuneML/ml_methods/AtchleyKmerMILClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,8 @@ def _make_log_reg(self):

self.logistic_regression = PyTorchLogisticRegression(in_features=self.input_size, zero_abundance_weight_init=self.zero_abundance_weight_init)

def _check_encoded_data(self, encoded_data: EncodedData):
assert encoded_data.encoding == 'AtchleyKmerEncoder', f"AtchleyKmerMILClassifier: the encoding is not compatible with the given classifier. " \
f"Expected AtchleyKmer encoding, got {encoded_data.encoding} instead. "

def fit(self, encoded_data: EncodedData, label_name: str, cores_for_training: int = 2):
self.feature_names = encoded_data.feature_names
self._check_encoded_data(encoded_data)

Util.setup_pytorch(self.number_of_threads, self.random_seed)
self.input_size = encoded_data.examples.shape[1]
Expand Down Expand Up @@ -227,3 +222,7 @@ def get_classes(self) -> list:

def get_class_mapping(self) -> dict:
return self.class_mapping

def get_compatible_encoders(self):
from immuneML.encodings.atchley_kmer_encoding.AtchleyKmerEncoder import AtchleyKmerEncoder
return [AtchleyKmerEncoder]
4 changes: 3 additions & 1 deletion immuneML/ml_methods/DeepRC.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ def _prepare_caching_params(self, encoded_data: EncodedData, type: str, label_na
("pytorch_device_name", self.pytorch_device_name))

def fit(self, encoded_data: EncodedData, label_name: str, cores_for_training: int = 2):
assert encoded_data.encoding == "DeepRCEncoder", f"DeepRC: ML method DeepRC is only compatible with the DeepRC encoder, found {encoded_data.encoding.replace('Encoder','')} encoder"
self.feature_names = encoded_data.feature_names
self._set_label_classes({label_name: encoded_data.labels[label_name]})
self.model = CacheHandler.memo_by_params(self._prepare_caching_params(encoded_data, "fit", label_name),
Expand Down Expand Up @@ -429,3 +428,6 @@ def get_class_mapping(self) -> dict:

def get_label(self) -> str:
return self.label

def get_compatible_encoders(self):
return [DeepRCEncoder]
9 changes: 9 additions & 0 deletions immuneML/ml_methods/KNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ def get_params(self):
def can_predict_proba(self) -> bool:
return True

def get_compatible_encoders(self):
from immuneML.encodings.distance_encoding.DistanceEncoder import DistanceEncoder
from immuneML.encodings.evenness_profile.EvennessProfileEncoder import EvennessProfileEncoder
from immuneML.encodings.filtered_sequence_encoding.SequenceAbundanceEncoder import SequenceAbundanceEncoder
from immuneML.encodings.kmer_frequency.KmerFrequencyEncoder import KmerFrequencyEncoder
from immuneML.encodings.onehot.OneHotEncoder import OneHotEncoder
from immuneML.encodings.word2vec.Word2VecEncoder import Word2VecEncoder
return [KmerFrequencyEncoder, OneHotEncoder, Word2VecEncoder, SequenceAbundanceEncoder, EvennessProfileEncoder, DistanceEncoder]

@staticmethod
def get_documentation():
doc = str(KNN.__doc__)
Expand Down
19 changes: 19 additions & 0 deletions immuneML/ml_methods/MLMethod.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,22 @@ def can_predict_proba(self) -> bool:
def get_class_mapping(self) -> dict:
"""Returns a dictionary containing the mapping between label values and values internally used in the classifier"""
pass

@abc.abstractmethod
def get_compatible_encoders(self):
pass

def check_encoder_compatibility(self, encoder):
"""Checks whether the given encoder is compatible with this ML method, and throws an error if it is not."""
is_valid = False

for encoder_class in self.get_compatible_encoders():
if issubclass(encoder.__class__, encoder_class):
is_valid = True
break

if not is_valid:
raise ValueError(f"{encoder.__class__.__name__} is not compatible with ML Method {self.__class__.__name__}. "
f"Please use one of the following encoders instead: {', '.join([enc_class.__name__ for enc_class in self.get_compatible_encoders()])}")


4 changes: 4 additions & 0 deletions immuneML/ml_methods/ProbabilisticBinaryClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,7 @@ def get_classes(self) -> list:

def get_class_mapping(self) -> dict:
return self.class_mapping

def get_compatible_encoders(self):
from immuneML.encodings.filtered_sequence_encoding.SequenceAbundanceEncoder import SequenceAbundanceEncoder
return [SequenceAbundanceEncoder]
4 changes: 4 additions & 0 deletions immuneML/ml_methods/ReceptorCNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,7 @@ def get_classes(self) -> list:

def get_class_mapping(self) -> dict:
return self.class_mapping

def get_compatible_encoders(self):
from immuneML.encodings.onehot.OneHotEncoder import OneHotEncoder
return [OneHotEncoder]
7 changes: 7 additions & 0 deletions immuneML/ml_methods/SklearnMethod.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,13 @@ def get_class_mapping(self) -> dict:
"""Returns a dictionary containing the mapping between label values and values internally used in the classifier"""
return self.class_mapping

def get_compatible_encoders(self):
from immuneML.encodings.evenness_profile.EvennessProfileEncoder import EvennessProfileEncoder
from immuneML.encodings.kmer_frequency.KmerFrequencyEncoder import KmerFrequencyEncoder
from immuneML.encodings.onehot.OneHotEncoder import OneHotEncoder
from immuneML.encodings.word2vec.Word2VecEncoder import Word2VecEncoder
return [KmerFrequencyEncoder, OneHotEncoder, Word2VecEncoder, EvennessProfileEncoder]

@staticmethod
def get_usage_documentation(model_name):
return f"""
Expand Down
3 changes: 3 additions & 0 deletions immuneML/ml_methods/TCRdistClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,6 @@ def can_predict_proba(self) -> bool:
def _get_model_filename(self):
return "tcrdist_classifier"

def get_compatible_encoders(self):
from immuneML.encodings.distance_encoding.TCRdistEncoder import TCRdistEncoder
return [TCRdistEncoder]
6 changes: 3 additions & 3 deletions test/integration_tests/test_sequenceAbundanceEncoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def test_encoding(self):
},
"ml_methods": {
"knn": {
"KNN": {
"n_neighbors": 1
},
"KNN": {
"n_neighbors": 1
},
}
},
"reports": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_generate(self):
"knn": {
"KNN": {
"n_neighbors": 1
},
}
}
},
"reports": {
Expand Down

0 comments on commit 426cfe9

Please sign in to comment.