Skip to content

Commit

Permalink
Update tests for fasta reader and loader
Browse files Browse the repository at this point in the history
  • Loading branch information
breimanntools committed May 3, 2024
1 parent 51bd7ee commit 1b6afaa
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 137 deletions.
5 changes: 3 additions & 2 deletions aaanalysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .data_handling import (load_dataset, load_scales, load_features,
to_fasta, read_fasta,
encode_sequences)
filter_seq, encode_seq)
from .feature_engineering import AAclust, AAclustPlot, SequenceFeature, NumericalFeature, CPP, CPPPlot
from .pu_learning import dPULearn, dPULearnPlot
from .pertubation import AAMut, AAMutPlot, SeqMut, SeqMutPlot
Expand All @@ -17,7 +17,8 @@
"load_features",
"to_fasta",
"read_fasta",
"encode_sequences",
"filter_seq",
"encode_seq",
"AAclust",
"AAclustPlot",
"SequenceFeature",
Expand Down
6 changes: 4 additions & 2 deletions aaanalysis/data_handling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from ._load_features import load_features
from ._read_fasta import read_fasta
from ._to_fasta import to_fasta
from ._encode_sequences import encode_sequences
from ._filter_seq import filter_seq
from ._encode_seq import encode_seq

__all__ = [
"load_dataset",
"load_scales",
"load_features",
"to_fasta",
"encode_sequences",
"filter_seq",
"encode_seq",
]
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def _one_hot_encode(amino_acid=None, alphabet=None, gap="_"):

# II Main Functions
# TODO finish, docu, test, example ..
def encode_sequences(sequences=None,
alphabet='ARNDCEQGHILKMFPSTWYV',
gap="_",
pad_at='C'):
def encode_seq(list_seq=None,
alphabet='ARNDCEQGHILKMFPSTWYV',
gap="_",
pad_at='C'):
"""
One-hot-encode a list of protein sequences into a feature matrix, padding shorter sequences
with gaps represented as zero vectors.
Expand Down Expand Up @@ -71,13 +71,13 @@ def encode_sequences(sequences=None,
raise ValueError(f"pad_at must be 'N' or 'C', got {pad_at}")

# Validate if all characters in the sequences are within the given alphabet
all_chars = set(''.join(sequences))
all_chars = set(''.join(list_seq))
if not all_chars.issubset(set(alphabet + '_')):
invalid_chars = all_chars - set(alphabet + '_')
raise ValueError(f"Found invalid amino acid(s) {invalid_chars} not in alphabet.")

# Pad sequences
padded_sequences = _pad_sequences(sequences, pad_at=pad_at)
padded_sequences = _pad_sequences(list_seq, pad_at=pad_at)
max_length = len(padded_sequences[0])
num_amino_acids = len(alphabet)
feature_matrix = np.zeros((len(padded_sequences), max_length * num_amino_acids), dtype=int)
Expand Down
111 changes: 111 additions & 0 deletions aaanalysis/data_handling/_filter_seq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import subprocess
import shutil

# TODO test, adjust, finish (see ChatGPT: Model Performance Correlation Analysis including STD)


# I Helper functions
def _is_tool(name):
"""Check whether `name` is on PATH and marked as executable."""
return shutil.which(name) is not None


def _select_longest_representatives(cluster_tsv, all_sequences_file, output_file):
seq_dict = {}
with open(all_sequences_file, 'r') as file:
current_id = None
for line in file:
if line.startswith('>'):
current_id = line.strip().split()[0][1:]
seq_dict[current_id] = ""
else:
seq_dict[current_id] += line.strip()

clusters = {}
with open(cluster_tsv, 'r') as file:
for line in file:
parts = line.strip().split('\t')
cluster_id, seq_id = parts[0], parts[1]
if cluster_id not in clusters:
clusters[cluster_id] = []
clusters[cluster_id].append(seq_id)

representatives = {}
for cluster_id, seq_ids in clusters.items():
longest_seq_id = max(seq_ids, key=lambda x: len(seq_dict[x]))
representatives[longest_seq_id] = seq_dict[longest_seq_id]

with open(output_file, 'w') as out_file:
for seq_id, sequence in representatives.items():
out_file.write(f">{seq_id}\n{sequence}\n")


def _run_cd_hit(input_file, output_file, similarity_threshold, word_size, threads, coverage_long=None, coverage_short=None, verbose=False):
"""Helper function to run CD-HIT with provided parameters."""
cmd = [
"cd-hit", "-i", input_file, "-o", output_file,
"-c", str(similarity_threshold), "-n", str(word_size), "-T", str(threads)
]
if coverage_long:
cmd.extend(["-aL", str(coverage_long)])
if coverage_short:
cmd.extend(["-aS", str(coverage_short)])
if verbose:
cmd.extend(["-d", "0"])
subprocess.run(cmd, check=True)
print("CD-HIT clustering completed. Representatives are saved in:", output_file)


def _run_mmseq(input_file, output_file, similarity_threshold, word_size, threads, verbose=False):
"""Helper function to run MMSeq2 with provided parameters."""
tmp_directory = "tmp"
result_prefix = "result_"
db_name = result_prefix + "DB"
cluster_name = result_prefix + "Clu"

subprocess.run(["mmseqs", "createdb", input_file, db_name], check=True)
cmd = [
"mmseqs", "cluster", db_name, cluster_name, tmp_directory,
"--min-seq-id", str(similarity_threshold), "-k", str(word_size), "--threads", str(threads)
]
if verbose:
cmd.extend(["-v", "3"])

subprocess.run(cmd, check=True)
cluster_tsv = f"{result_prefix}Clu.tsv"
subprocess.run(["mmseqs", "createtsv", db_name, db_name, cluster_name, cluster_tsv], check=True)

_select_longest_representatives(cluster_tsv, input_file, output_file)
print("MMseq2 clustering completed. Representatives are saved in:", output_file)


# II Main function
def filter_seq(method, input_file, output_file, similarity_threshold=0.7, word_size=5, coverage_long=None, coverage_short=None, threads=1, verbose=False):
"""Perform redundancy-reduction of sequences by calling CD-Hit or MMSeq2 algorithm."""
if method not in ['cd-hit', 'mmseq2']:
raise ValueError("Invalid method specified. Use 'cd-hit' or 'mmseq2'.")

if method == 'cd-hit' and not _is_tool('cd-hit'):
raise RuntimeError("CD-HIT is not installed or not in the PATH.")
elif method == 'mmseq2' and not _is_tool('mmseqs'):
raise RuntimeError("MMseq2 is not installed or not in the PATH.")

if method == "cd-hit":
_run_cd_hit(input_file=input_file, output_file=output_file,
similarity_threshold=similarity_threshold, word_size=word_size,
threads=threads, coverage_long=coverage_long, coverage_short=coverage_short, verbose=verbose)
else:
_run_mmseq(input_file=input_file, output_file=output_file,
similarity_threshold=similarity_threshold, word_size=word_size,
threads=threads, verbose=verbose)


# Example usage
"""
input_fasta = ".fasta"
final_output = "representatives.fasta"
method = "cd-hit" # Change to "mmseq2" to use MMseq2
filter_seq(method, input_fasta, final_output, similarity_threshold=0.7, word_size=5,
coverage_long=0.8, coverage_short=0.8, threads=4, verbose=True)
"""

106 changes: 0 additions & 106 deletions aaanalysis/data_handling/_filter_sequences.py

This file was deleted.

18 changes: 12 additions & 6 deletions tests/unit/data_handling_tests/test_read_fasta.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
FILE_IN = "valid_path.fasta"
FILE_DB_IN = "valid_path_db.fasta"
COL_DB = "database"
ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"


def creat_mock_file():
""""""
Expand All @@ -31,21 +33,21 @@ def test_file_path(self):
df = aa.read_fasta(file_path=FILE_IN)
assert isinstance(df, pd.DataFrame) # Expecting a DataFrame to be returned

@given(col_id=st.text(min_size=1))
@given(col_id=st.text(min_size=1, alphabet=ALPHABET))
def test_col_id(self, col_id):
"""Test valid 'col_id' parameter."""
creat_mock_file()
df = aa.read_fasta(file_path=FILE_IN, col_id=col_id)
assert col_id in df.columns

@given(col_seq=st.text(min_size=1))
@given(col_seq=st.text(min_size=1, alphabet=ALPHABET))
def test_col_seq(self, col_seq):
"""Test valid 'col_seq' parameter."""
creat_mock_file()
df = aa.read_fasta(file_path=FILE_IN, col_seq=col_seq)
assert col_seq in df.columns

@given(cols_info=st.lists(st.text(min_size=1), min_size=1, max_size=1))
@given(cols_info=st.lists(st.text(min_size=1, alphabet=ALPHABET), min_size=1, max_size=1))
def test_cols_info(self, cols_info):
"""Test valid 'cols_info' parameter."""
creat_mock_file()
Expand All @@ -61,7 +63,7 @@ def test_col_db(self):
df = aa.read_fasta(file_path=FILE_DB_IN, col_db=COL_DB)
assert COL_DB in df.columns

@given(sep=st.text(min_size=1, max_size=1))
@given(sep=st.text(min_size=1, max_size=1, alphabet=",|;-"))
def test_sep(self, sep):
"""Test valid 'sep' parameter."""
creat_mock_file()
Expand Down Expand Up @@ -113,7 +115,9 @@ def test_invalid_sep(self):
class TestReadFastaComplex:
"""Test aa.read_fasta function with complex scenarios"""

@given(col_id=st.text(min_size=1), col_seq=st.text(min_size=1), sep=st.text(min_size=1, max_size=1))
@given(col_id=st.text(min_size=1, alphabet=ALPHABET),
col_seq=st.text(min_size=1, alphabet=ALPHABET),
sep=st.text(min_size=1, max_size=1, alphabet=",|;-"))
def test_combination_valid_inputs(self, col_id, col_seq, sep):
"""Test valid combinations of parameters."""
creat_mock_file()
Expand All @@ -127,7 +131,9 @@ def test_combination_valid_inputs(self, col_id, col_seq, sep):
except Exception as e:
assert isinstance(e, (FileNotFoundError, ValueError))

@given(col_id=st.text(max_size=0), col_seq=st.text(min_size=1), sep=st.integers())
@given(col_id=st.text(max_size=0, alphabet=ALPHABET),
col_seq=st.text(min_size=1, alphabet=ALPHABET),
sep=st.integers())
def test_combination_invalid(self, col_id, col_seq, sep):
"""Test invalid 'col_id' in combination with other parameters."""
creat_mock_file()
Expand Down
Loading

0 comments on commit 1b6afaa

Please sign in to comment.