Skip to content

Commit

Permalink
Added utility func for EncoderHelper to automatically encode labels f…
Browse files Browse the repository at this point in the history
…or datasets
  • Loading branch information
LonnekeScheffer committed Mar 19, 2024
1 parent 5570ff6 commit 452bd37
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
8 changes: 5 additions & 3 deletions immuneML/encodings/deeprc/DeepRCEncoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from immuneML.data_model.encoded_data.EncodedData import EncodedData
from immuneML.encodings.DatasetEncoder import DatasetEncoder
from immuneML.encodings.EncoderParams import EncoderParams
from immuneML.util.EncoderHelper import EncoderHelper
from immuneML.util.PathBuilder import PathBuilder


Expand Down Expand Up @@ -88,11 +89,12 @@ def encode(self, dataset, params: EncoderParams) -> RepertoireDataset:

self.export_repertoire_tsv_files(result_path)

labels = params.label_config.get_labels_by_name()
metadata_filepath = self.export_metadata_file(dataset, labels, result_path)

metadata_filepath = self.export_metadata_file(dataset, params.label_config.get_labels_by_name(), result_path)

encoded_dataset = dataset.clone()
encoded_dataset.encoded_data = EncodedData(examples=None, labels=dataset.get_metadata(labels) if params.encode_labels else None,
encoded_dataset.encoded_data = EncodedData(examples=None,
labels=EncoderHelper.encode_dataset_labels(dataset, params.label_config, params.encode_labels),
example_ids=dataset.get_repertoire_ids(),
example_weights=dataset.get_example_weights(),
encoding=DeepRCEncoder.__name__,
Expand Down
24 changes: 20 additions & 4 deletions immuneML/util/EncoderHelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,7 @@ def check_dataset_type_available_in_mapping(dataset, class_name):

@staticmethod
def encode_element_dataset_labels(dataset: ElementDataset, label_config: LabelConfiguration):
'''
Automatically generates the encoded labels for an ElementDataset (= SequenceDataset or ReceptorDataset)
'''

'''Automatically generates the encoded labels for an ElementDataset (= SequenceDataset or ReceptorDataset)'''
labels = {name: [] for name in label_config.get_labels_by_name()}

for sequence in dataset.get_data():
Expand All @@ -84,6 +81,25 @@ def encode_element_dataset_labels(dataset: ElementDataset, label_config: LabelCo

return labels

@staticmethod
def encode_repertoire_dataset_labels(dataset: RepertoireDataset, label_config: LabelConfiguration):
    '''
    Automatically generates the encoded labels for a RepertoireDataset.

    The label names are taken from label_config.get_labels_by_name() and passed to
    dataset.get_metadata(); whatever get_metadata returns is returned unchanged.
    '''
    return dataset.get_metadata(label_config.get_labels_by_name())

@staticmethod
def encode_dataset_labels(dataset: Dataset, label_config: LabelConfiguration, encode_labels: bool = True):
    '''
    Automatically generates the encoded labels for a Dataset.

    The result contains labels in the following format:
    {'label_name': ['label_class1', 'label_class2', 'label_class2']}
    where the inner list(s) contain the class label for each example in the dataset.

    When encode_labels is falsy, no labels are generated and None is returned.
    '''
    if encode_labels:
        # RepertoireDatasets are handled by the repertoire helper; every other dataset
        # is handed to the element-dataset helper.
        if isinstance(dataset, RepertoireDataset):
            label_encoder = EncoderHelper.encode_repertoire_dataset_labels
        else:
            label_encoder = EncoderHelper.encode_element_dataset_labels

        return label_encoder(dataset, label_config)

    return None

@staticmethod
def check_positive_class_labels(label_config: LabelConfiguration, location: str):
'''
Expand Down

0 comments on commit 452bd37

Please sign in to comment.