diff --git a/aaanalysis/__init__.py b/aaanalysis/__init__.py
index 068504e2..f31a77b2 100644
--- a/aaanalysis/__init__.py
+++ b/aaanalysis/__init__.py
@@ -1,6 +1,6 @@
 from .data_handling import (load_dataset, load_scales, load_features, to_fasta, read_fasta,
-                            encode_sequences)
+                            filter_seq, encode_seq)
 from .feature_engineering import AAclust, AAclustPlot, SequenceFeature, NumericalFeature, CPP, CPPPlot
 from .pu_learning import dPULearn, dPULearnPlot
 from .pertubation import AAMut, AAMutPlot, SeqMut, SeqMutPlot
@@ -17,7 +17,8 @@
     "load_features",
     "to_fasta",
     "read_fasta",
-    "encode_sequences",
+    "filter_seq",
+    "encode_seq",
     "AAclust",
     "AAclustPlot",
     "SequenceFeature",
diff --git a/aaanalysis/data_handling/__init__.py b/aaanalysis/data_handling/__init__.py
index b1632515..94934c97 100644
--- a/aaanalysis/data_handling/__init__.py
+++ b/aaanalysis/data_handling/__init__.py
@@ -3,12 +3,14 @@
 from ._load_features import load_features
 from ._read_fasta import read_fasta
 from ._to_fasta import to_fasta
-from ._encode_sequences import encode_sequences
+from ._filter_seq import filter_seq
+from ._encode_seq import encode_seq
 
 __all__ = [
     "load_dataset",
     "load_scales",
     "load_features",
     "to_fasta",
-    "encode_sequences",
+    "filter_seq",
+    "encode_seq",
 ]
diff --git a/aaanalysis/data_handling/_encode_sequences.py b/aaanalysis/data_handling/_encode_seq.py
similarity index 92%
rename from aaanalysis/data_handling/_encode_sequences.py
rename to aaanalysis/data_handling/_encode_seq.py
index cfc4ba09..fa7434d2 100644
--- a/aaanalysis/data_handling/_encode_sequences.py
+++ b/aaanalysis/data_handling/_encode_seq.py
@@ -39,10 +39,10 @@ def _one_hot_encode(amino_acid=None, alphabet=None, gap="_"):
 
 # II Main Functions
 # TODO finish, docu, test, example ..
-def encode_sequences(sequences=None,
-                     alphabet='ARNDCEQGHILKMFPSTWYV',
-                     gap="_",
-                     pad_at='C'):
+def encode_seq(list_seq=None,
+               alphabet='ARNDCEQGHILKMFPSTWYV',
+               gap="_",
+               pad_at='C'):
     """
     One-hot-encode a list of protein sequences into a feature matrix, padding shorter sequences
     with gaps represented as zero vectors.
@@ -71,13 +71,13 @@ def encode_sequences(sequences=None,
         raise ValueError(f"pad_at must be 'N' or 'C', got {pad_at}")
 
     # Validate if all characters in the sequences are within the given alphabet
-    all_chars = set(''.join(sequences))
+    all_chars = set(''.join(list_seq))
     if not all_chars.issubset(set(alphabet + '_')):
         invalid_chars = all_chars - set(alphabet + '_')
         raise ValueError(f"Found invalid amino acid(s) {invalid_chars} not in alphabet.")
 
     # Pad sequences
-    padded_sequences = _pad_sequences(sequences, pad_at=pad_at)
+    padded_sequences = _pad_sequences(list_seq, pad_at=pad_at)
     max_length = len(padded_sequences[0])
     num_amino_acids = len(alphabet)
     feature_matrix = np.zeros((len(padded_sequences), max_length * num_amino_acids), dtype=int)
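
Reviewer note: a minimal usage sketch for the renamed public encode_seq API. The diff does not show the function's return statement, so treating the result as the one-hot feature matrix built above (and its shape) is an assumption; the toy sequences are illustrative.

import aaanalysis as aa

# Two toy sequences; the second is one residue shorter and should be padded
# at the C-terminus with the gap symbol (encoded as an all-zero block).
list_seq = ["ACDE", "ACD"]
X = aa.encode_seq(list_seq=list_seq, alphabet="ARNDCEQGHILKMFPSTWYV", gap="_", pad_at="C")
# Assumption: X is the feature matrix constructed in the function, i.e. shape (2, 4 * 20).
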
diff --git a/aaanalysis/data_handling/_filter_seq.py b/aaanalysis/data_handling/_filter_seq.py
new file mode 100644
index 00000000..bfae9ebc
--- /dev/null
+++ b/aaanalysis/data_handling/_filter_seq.py
@@ -0,0 +1,111 @@
+import subprocess
+import shutil
+
+# TODO test, adjust, finish (see ChatGPT: Model Performance Correlation Analysis including STD)
+
+
+# I Helper functions
+def _is_tool(name):
+    """Check whether `name` is on PATH and marked as executable."""
+    return shutil.which(name) is not None
+
+
+def _select_longest_representatives(cluster_tsv, all_sequences_file, output_file):
+    seq_dict = {}
+    with open(all_sequences_file, 'r') as file:
+        current_id = None
+        for line in file:
+            if line.startswith('>'):
+                current_id = line.strip().split()[0][1:]
+                seq_dict[current_id] = ""
+            else:
+                seq_dict[current_id] += line.strip()
+
+    clusters = {}
+    with open(cluster_tsv, 'r') as file:
+        for line in file:
+            parts = line.strip().split('\t')
+            cluster_id, seq_id = parts[0], parts[1]
+            if cluster_id not in clusters:
+                clusters[cluster_id] = []
+            clusters[cluster_id].append(seq_id)
+
+    representatives = {}
+    for cluster_id, seq_ids in clusters.items():
+        longest_seq_id = max(seq_ids, key=lambda x: len(seq_dict[x]))
+        representatives[longest_seq_id] = seq_dict[longest_seq_id]
+
+    with open(output_file, 'w') as out_file:
+        for seq_id, sequence in representatives.items():
+            out_file.write(f">{seq_id}\n{sequence}\n")
+
+
+def _run_cd_hit(input_file, output_file, similarity_threshold, word_size, threads, coverage_long=None, coverage_short=None, verbose=False):
+    """Helper function to run CD-HIT with provided parameters."""
+    cmd = [
+        "cd-hit", "-i", input_file, "-o", output_file,
+        "-c", str(similarity_threshold), "-n", str(word_size), "-T", str(threads)
+    ]
+    if coverage_long:
+        cmd.extend(["-aL", str(coverage_long)])
+    if coverage_short:
+        cmd.extend(["-aS", str(coverage_short)])
+    if verbose:
+        cmd.extend(["-d", "0"])
+    subprocess.run(cmd, check=True)
+    print("CD-HIT clustering completed. Representatives are saved in:", output_file)
+
+
+def _run_mmseq(input_file, output_file, similarity_threshold, word_size, threads, verbose=False):
+    """Helper function to run MMSeq2 with provided parameters."""
+    tmp_directory = "tmp"
+    result_prefix = "result_"
+    db_name = result_prefix + "DB"
+    cluster_name = result_prefix + "Clu"
+
+    subprocess.run(["mmseqs", "createdb", input_file, db_name], check=True)
+    cmd = [
+        "mmseqs", "cluster", db_name, cluster_name, tmp_directory,
+        "--min-seq-id", str(similarity_threshold), "-k", str(word_size), "--threads", str(threads)
+    ]
+    if verbose:
+        cmd.extend(["-v", "3"])
+
+    subprocess.run(cmd, check=True)
+    cluster_tsv = f"{result_prefix}Clu.tsv"
+    subprocess.run(["mmseqs", "createtsv", db_name, db_name, cluster_name, cluster_tsv], check=True)
+
+    _select_longest_representatives(cluster_tsv, input_file, output_file)
+    print("MMseq2 clustering completed. Representatives are saved in:", output_file)
+
+
+# II Main function
+def filter_seq(method, input_file, output_file, similarity_threshold=0.7, word_size=5, coverage_long=None, coverage_short=None, threads=1, verbose=False):
+    """Perform redundancy-reduction of sequences by calling CD-Hit or MMSeq2 algorithm."""
+    if method not in ['cd-hit', 'mmseq2']:
+        raise ValueError("Invalid method specified. Use 'cd-hit' or 'mmseq2'.")
+
+    if method == 'cd-hit' and not _is_tool('cd-hit'):
+        raise RuntimeError("CD-HIT is not installed or not in the PATH.")
+    elif method == 'mmseq2' and not _is_tool('mmseqs'):
+        raise RuntimeError("MMseq2 is not installed or not in the PATH.")
+
+    if method == "cd-hit":
+        _run_cd_hit(input_file=input_file, output_file=output_file,
+                    similarity_threshold=similarity_threshold, word_size=word_size,
+                    threads=threads, coverage_long=coverage_long, coverage_short=coverage_short, verbose=verbose)
+    else:
+        _run_mmseq(input_file=input_file, output_file=output_file,
+                   similarity_threshold=similarity_threshold, word_size=word_size,
+                   threads=threads, verbose=verbose)
+
+
+# Example usage
+"""
+input_fasta = ".fasta"
+final_output = "representatives.fasta"
+method = "cd-hit"  # Change to "mmseq2" to use MMseq2
+filter_seq(method, input_fasta, final_output, similarity_threshold=0.7, word_size=5,
+           coverage_long=0.8, coverage_short=0.8, threads=4, verbose=True)
+"""
+
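
Reviewer note: a usage sketch for the new filter_seq wrapper, with hypothetical file names; the cd-hit binary must be on PATH (the wrapper itself checks this via shutil.which). The keyword arguments map onto the CLI flags built in _run_cd_hit (-c, -n, -aL, -aS, -T).

import aaanalysis as aa

# Redundancy-reduce a FASTA file with CD-HIT at 70% identity;
# "proteins.fasta" and "proteins_nr.fasta" are hypothetical paths.
aa.filter_seq(method="cd-hit",
              input_file="proteins.fasta",
              output_file="proteins_nr.fasta",
              similarity_threshold=0.7,   # passed to CD-HIT as -c
              word_size=5,                # passed to CD-HIT as -n
              coverage_long=0.8,          # -aL
              coverage_short=0.8,         # -aS
              threads=4)                  # -T
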
diff --git a/aaanalysis/data_handling/_filter_sequences.py b/aaanalysis/data_handling/_filter_sequences.py
deleted file mode 100644
index 6c0730e0..00000000
--- a/aaanalysis/data_handling/_filter_sequences.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import subprocess
-import shutil
-
-# TODO test, adjust, finish (see ChatGPT: Model Performance Correlation Analysis including STD)
-
-
-# I Helper functions
-def _is_tool(name):
-    """Check whether `name` is on PATH and marked as executable."""
-    return shutil.which(name) is not None
-
-
-def _select_longest_representatives(cluster_tsv, all_sequences_file, output_file):
-    seq_dict = {}
-    with open(all_sequences_file, 'r') as file:
-        current_id = None
-        for line in file:
-            if line.startswith('>'):
-                current_id = line.strip().split()[0][1:]
-                seq_dict[current_id] = ""
-            else:
-                seq_dict[current_id] += line.strip()
-
-    clusters = {}
-    with open(cluster_tsv, 'r') as file:
-        for line in file:
-            parts = line.strip().split('\t')
-            cluster_id, seq_id = parts[0], parts[1]
-            if cluster_id not in clusters:
-                clusters[cluster_id] = []
-            clusters[cluster_id].append(seq_id)
-
-    representatives = {}
-    for cluster_id, seq_ids in clusters.items():
-        longest_seq_id = max(seq_ids, key=lambda x: len(seq_dict[x]))
-        representatives[longest_seq_id] = seq_dict[longest_seq_id]
-
-    with open(output_file, 'w') as out_file:
-        for seq_id, sequence in representatives.items():
-            out_file.write(f">{seq_id}\n{sequence}\n")
-
-
-# II Main Functions
-# TODO finish, docu, test, example ..
-def filter_sequences(method, input_file, output_file, similarity_threshold=0.7, word_size=5,
-                     coverage_long=None, coverage_short=None, threads=1, verbose=False):
-    """Perform redundancy-reduction of sequences by calling CD-Hit or MMSeq2 algorithm"""
-    if method not in ['cd-hit', 'mmseq2']:
-        raise ValueError("Invalid method specified. Use 'cd-hit' or 'mmseq2'.")
-
-    if method == 'cd-hit' and not _is_tool('cd-hit'):
-        raise RuntimeError("CD-HIT is not installed or not in the PATH.")
-    if method == 'mmseq2' and not _is_tool('mmseqs'):
-        raise RuntimeError("MMseq2 is not installed or not in the PATH.")
-
-    if method == "cd-hit":
-        cmd = [
-            "cd-hit", "-i", input_file, "-o", output_file,
-            "-c", str(similarity_threshold), "-n", str(word_size),
-            "-T", str(threads)
-        ]
-        if coverage_long:
-            cmd.extend(["-aL", str(coverage_long)])
-        if coverage_short:
-            cmd.extend(["-aS", str(coverage_short)])
-        if verbose:
-            cmd.append("-d")
-            cmd.append("0")
-
-        subprocess.run(cmd, check=True)
-        print("CD-HIT clustering completed. Representatives are saved in:", output_file)
-
-    elif method == "mmseq2":
-        tmp_directory = "tmp"
-        result_prefix = "result_"
-        db_name = result_prefix + "DB"
-        cluster_name = result_prefix + "Clu"
-        subprocess.run(["mmseqs", "createdb", input_file, db_name], check=True)
-        cmd = [
-            "mmseqs", "cluster", db_name, cluster_name, tmp_directory,
-            "--min-seq-id", str(similarity_threshold), "-k", str(word_size),
-            "--threads", str(threads)
-        ]
-        if verbose:
-            cmd.append("-v")
-            cmd.append("3")
-
-        subprocess.run(cmd, check=True)
-        cluster_tsv = result_prefix + "Clu.tsv"
-        subprocess.run([
-            "mmseqs", "createtsv", db_name, db_name, cluster_name, cluster_tsv
-        ], check=True)
-
-        _select_longest_representatives(cluster_tsv, input_file, output_file)
-        print("MMseq2 clustering completed. Representatives are saved in:", output_file)
-
-
-# Example usage
-input_fasta = "your_input_sequences.fasta"
-final_output = "representatives.fasta"
-method = "cd-hit"  # Change to "mmseq2" to use MMseq2
-
-filter_sequences(method, input_fasta, final_output, similarity_threshold=0.7, word_size=5,
-                 coverage_long=0.8, coverage_short=0.8, threads=4, verbose=True)
-
-
diff --git a/tests/unit/data_handling_tests/test_read_fasta.py b/tests/unit/data_handling_tests/test_read_fasta.py
index a20a6679..fc4048dc 100644
--- a/tests/unit/data_handling_tests/test_read_fasta.py
+++ b/tests/unit/data_handling_tests/test_read_fasta.py
@@ -11,6 +11,8 @@
 FILE_IN = "valid_path.fasta"
 FILE_DB_IN = "valid_path_db.fasta"
 COL_DB = "database"
+ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
 
 def creat_mock_file():
     """"""
@@ -31,21 +33,21 @@ def test_file_path(self):
         df = aa.read_fasta(file_path=FILE_IN)
         assert isinstance(df, pd.DataFrame)  # Expecting a DataFrame to be returned
 
-    @given(col_id=st.text(min_size=1))
+    @given(col_id=st.text(min_size=1, alphabet=ALPHABET))
     def test_col_id(self, col_id):
         """Test valid 'col_id' parameter."""
         creat_mock_file()
         df = aa.read_fasta(file_path=FILE_IN, col_id=col_id)
         assert col_id in df.columns
 
-    @given(col_seq=st.text(min_size=1))
+    @given(col_seq=st.text(min_size=1, alphabet=ALPHABET))
     def test_col_seq(self, col_seq):
         """Test valid 'col_seq' parameter."""
         creat_mock_file()
         df = aa.read_fasta(file_path=FILE_IN, col_seq=col_seq)
         assert col_seq in df.columns
 
-    @given(cols_info=st.lists(st.text(min_size=1), min_size=1, max_size=1))
+    @given(cols_info=st.lists(st.text(min_size=1, alphabet=ALPHABET), min_size=1, max_size=1))
     def test_cols_info(self, cols_info):
         """Test valid 'cols_info' parameter."""
         creat_mock_file()
@@ -61,7 +63,7 @@ def test_col_db(self):
         df = aa.read_fasta(file_path=FILE_DB_IN, col_db=COL_DB)
         assert COL_DB in df.columns
 
-    @given(sep=st.text(min_size=1, max_size=1))
+    @given(sep=st.text(min_size=1, max_size=1, alphabet=",|;-"))
     def test_sep(self, sep):
         """Test valid 'sep' parameter."""
         creat_mock_file()
@@ -113,7 +115,9 @@ def test_invalid_sep(self):
 
 class TestReadFastaComplex:
     """Test aa.read_fasta function with complex scenarios"""
-    @given(col_id=st.text(min_size=1), col_seq=st.text(min_size=1), sep=st.text(min_size=1, max_size=1))
+    @given(col_id=st.text(min_size=1, alphabet=ALPHABET),
+           col_seq=st.text(min_size=1, alphabet=ALPHABET),
+           sep=st.text(min_size=1, max_size=1, alphabet=",|;-"))
     def test_combination_valid_inputs(self, col_id, col_seq, sep):
         """Test valid combinations of parameters."""
         creat_mock_file()
@@ -127,7 +131,9 @@ def test_combination_valid_inputs(self, col_id, col_seq, sep):
         except Exception as e:
             assert isinstance(e, (FileNotFoundError, ValueError))
 
-    @given(col_id=st.text(max_size=0), col_seq=st.text(min_size=1), sep=st.integers())
+    @given(col_id=st.text(max_size=0, alphabet=ALPHABET),
+           col_seq=st.text(min_size=1, alphabet=ALPHABET),
+           sep=st.integers())
     def test_combination_invalid(self, col_id, col_seq, sep):
         """Test invalid 'col_id' in combination with other parameters."""
         creat_mock_file()
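
Reviewer note: restricting the Hypothesis strategies to an explicit alphabet presumably keeps generated column names and separators free of whitespace, control characters, and FASTA-significant characters such as '>'. A self-contained sketch of the pattern used above (hypothesis is assumed to be installed, as in the existing suite; the test name is illustrative):

from hypothesis import given, strategies as st

ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

@given(col_id=st.text(min_size=1, alphabet=ALPHABET),
       sep=st.text(min_size=1, max_size=1, alphabet=",|;-"))
def test_generated_tokens_are_safe(col_id, sep):
    # Uppercase-only names contain no whitespace or '>' and never collide
    # with the separator characters drawn from ",|;-".
    assert col_id.isalpha()
    assert sep not in col_id
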
diff --git a/tests/unit/data_handling_tests/test_to_fasta.py b/tests/unit/data_handling_tests/test_to_fasta.py
index d0971110..93cb9e76 100644
--- a/tests/unit/data_handling_tests/test_to_fasta.py
+++ b/tests/unit/data_handling_tests/test_to_fasta.py
@@ -8,6 +8,7 @@
 settings.load_profile("default")
 
 FILE_OUT = "valid_path_out.fasta"
+ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
 
 
 class TestToFasta:
@@ -19,31 +20,31 @@ def test_file_path_valid(self):
         df_seq = aa.load_dataset(name="SEQ_AMYLO", n=10)
         aa.to_fasta(df_seq=df_seq, file_path=FILE_OUT)
 
-    @given(col_id=st.text(min_size=1))
+    @given(col_id=st.text(min_size=1, alphabet=ALPHABET))
     def test_col_id_valid(self, col_id):
         """Test valid 'col_id' ensuring it exists in the DataFrame."""
         df = pd.DataFrame({col_id: ["id1"], "sequence": ["ATCG"]})
         aa.to_fasta(df_seq=df, file_path=FILE_OUT, col_id=col_id)
 
-    @given(col_seq=st.text(min_size=1))
+    @given(col_seq=st.text(min_size=1, alphabet=ALPHABET))
     def test_col_seq_valid(self, col_seq):
         """Test valid 'col_seq' ensuring it exists in the DataFrame."""
         df = pd.DataFrame({"entry": ["id1"], col_seq: ["ATCG"]})
         aa.to_fasta(df_seq=df, file_path=FILE_OUT, col_seq=col_seq)
 
-    @given(sep=st.text(min_size=1, max_size=1))
+    @given(sep=st.text(min_size=1, max_size=1, alphabet=",|;-"))
     def test_sep_valid(self, sep):
         """Test valid 'sep' to check if it correctly separates information."""
         df = pd.DataFrame({"entry": ["id1"], "sequence": ["ATCG"]})
         aa.to_fasta(df_seq=df, file_path=FILE_OUT, sep=sep)
 
-    @given(col_db=st.text(min_size=1))
+    @given(col_db=st.text(min_size=1, alphabet=ALPHABET))
     def test_col_db_valid(self, col_db):
         """Test valid 'col_db' ensuring it is correctly added to the header."""
         df = pd.DataFrame({"entry": ["id1"], "sequence": ["ATCG"], col_db: ["DB001"]})
         aa.to_fasta(df_seq=df, file_path=FILE_OUT, col_db=col_db)
 
-    @given(cols_info=st.lists(st.text(min_size=1), min_size=1, max_size=3))
+    @given(cols_info=st.lists(st.text(min_size=1, alphabet=ALPHABET), min_size=1, max_size=3))
     def test_cols_info_valid(self, cols_info):
         """Test valid 'cols_info' ensuring they are correctly added to the header."""
         df = pd.DataFrame({"entry": ["id1"], "sequence": ["ATCG"], **{col: ["info"] for col in cols_info}})
@@ -104,22 +105,22 @@ def test_cols_info_invalid(self):
 
 class TestToFastaComplex:
     """Test complex scenarios involving multiple parameters in the 'to_fasta' function."""
-    @given(col_id=st.text(min_size=1),
-           col_seq=st.text(min_size=1),
-           sep=st.text(min_size=1, max_size=1),
-           col_db=st.text(min_size=1),
-           cols_info=st.lists(st.text(min_size=1), min_size=1, max_size=3))
+    @given(col_id=st.text(min_size=1, alphabet=ALPHABET),
+           col_seq=st.text(min_size=1, alphabet=ALPHABET),
+           sep=st.text(min_size=1, max_size=1, alphabet=ALPHABET),
+           col_db=st.text(min_size=1, alphabet=ALPHABET),
+           cols_info=st.lists(st.text(min_size=1, alphabet=ALPHABET), min_size=1, max_size=3))
     def test_valid_combination_all_parameters(self, col_id, col_seq, sep, col_db, cols_info):
         """Test with all parameters valid to ensure correct header formation and sequence output."""
         df = pd.DataFrame({col_id: ["id1"], col_seq: ["ATCG"], col_db: ["DB001"], **{col: ["info"] for col in cols_info}})
         aa.to_fasta(df_seq=df, file_path=FILE_OUT, col_id=col_id, col_seq=col_seq, sep=sep, col_db=col_db,
                     cols_info=cols_info)
 
-    @given(col_id=st.text(max_size=0),  # Invalid col_id (empty string)
-           col_seq=st.text(min_size=1),
+    @given(col_id=st.text(max_size=0, alphabet=ALPHABET),
+           col_seq=st.text(min_size=1, alphabet=ALPHABET),
            sep=st.integers(),
-           col_db=st.text(min_size=1),
-           cols_info=st.lists(st.text(min_size=1), min_size=1, max_size=3))
+           col_db=st.text(min_size=1, alphabet=ALPHABET),
+           cols_info=st.lists(st.text(min_size=1, alphabet=ALPHABET), min_size=1, max_size=3))
     def test_invalid_combination_parameters(self, col_id, col_seq, sep, col_db, cols_info):
         """Test invalid parameter combinations to ensure proper error handling."""
         df = pd.DataFrame({col_id: ["id1"], col_seq: ["ATCG"], col_db: ["DB001"], **{col: ["info"] for col in cols_info}})
diff --git a/tests/unit/data_handling_tests/valid_path_out.fasta b/tests/unit/data_handling_tests/valid_path_out.fasta
index e6b9ad23..3c07d894 100644
--- a/tests/unit/data_handling_tests/valid_path_out.fasta
+++ b/tests/unit/data_handling_tests/valid_path_out.fasta
@@ -1,2 +1,2 @@
->DB0010ATCG0info0info
+>DB001Bid1BinfoBinfo
 ATCG