diff --git a/mhcflurry/fasta.py b/mhcflurry/fasta.py index 88b87092b..e0f97eb64 100644 --- a/mhcflurry/fasta.py +++ b/mhcflurry/fasta.py @@ -16,13 +16,28 @@ import pandas -def read_fasta_to_dataframe(filename): +def read_fasta_to_dataframe(filename, full_descriptions=False): + """ + Parse a fasta file to a pandas DataFrame. + + Parameters + ---------- + filename : string + full_descriptions : bool + If true, instead of returning sequence IDs (the first space-separated + token), return the full description associated with each record. + Returns + ------- + pandas.DataFrame with columns "sequence_id" and "sequence". + """ reader = FastaParser() - rows = reader.iterate_over_file(filename) + rows = reader.iterate_over_file( + filename, full_descriptions=full_descriptions) return pandas.DataFrame( rows, columns=["sequence_id", "sequence"]) + class FastaParser(object): """ FastaParser object consumes lines of a FASTA file incrementally. @@ -31,7 +46,7 @@ def __init__(self): self.current_id = None self.current_lines = [] - def iterate_over_file(self, fasta_path): + def iterate_over_file(self, fasta_path, full_descriptions=False): """ Generator that yields identifiers paired with sequences. """ @@ -47,7 +62,8 @@ def iterate_over_file(self, fasta_path): if first_char == b">": previous_entry = self._current_entry() - self.current_id = self._parse_header_id(line) + self.current_id = self._parse_header_id( + line, full_description=full_descriptions) if len(self.current_id) == 0: logging.warning( @@ -97,7 +113,7 @@ def open_file(fasta_path): return open(fasta_path, 'rb') @staticmethod - def _parse_header_id(line): + def _parse_header_id(line, full_description=False): """ Pull the transcript or protein identifier from the header line which starts with '>' @@ -112,7 +128,7 @@ def _parse_header_id(line): # split line at first space to get the unique identifier for # this sequence space_index = line.find(b" ") - if space_index >= 0: + if space_index >= 0 and not full_description: identifier = line[1:space_index] else: identifier = line[1:] diff --git a/test/test_class1_processing_neural_network.py b/test/test_class1_processing_neural_network.py index 25bede146..d40a1ab24 100644 --- a/test/test_class1_processing_neural_network.py +++ b/test/test_class1_processing_neural_network.py @@ -25,9 +25,10 @@ teardown = cleanup setup = startup - -table = BLOSUM62_MATRIX.apply( - tuple).reset_index().set_index(0).to_dict()['index'] +table = dict([ + (tuple(encoding), amino_acid) + for amino_acid, encoding in BLOSUM62_MATRIX.iterrows() +]) def decode_matrix(array):