diff --git a/.gitignore b/.gitignore
index 17b9762..17a8200 100755
--- a/.gitignore
+++ b/.gitignore
@@ -13,4 +13,4 @@ data/
 [._]ss[a-gi-z]
 [._]sw[a-p]
-**/.DS_Store
\ No newline at end of file
+**/.DS_Store
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..c982118
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,19 @@
+repos:
+- repo: local
+  hooks:
+  - id: unittests
+    name: run unit tests
+    entry: python -m unittest
+    language: system
+    pass_filenames: false
+    args: ["discover"]
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v2.3.0
+  hooks:
+  - id: check-yaml
+  - id: end-of-file-fixer
+  - id: trailing-whitespace
+- repo: https://github.com/psf/black
+  rev: 24.3.0
+  hooks:
+  - id: black
diff --git a/CITATION.cff b/CITATION.cff
index a2af613..b6585f6 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -1,6 +1,7 @@
 cff-version: 1.2.0
 message: "If you use this software, please cite it as below."
 title: "RNA3DB: A dataset for training and benchmarking deep learning models for RNA structure prediction"
+version: 1.1
 authors:
 - given-names: "Marcell"
   family-names: "Szikszai"
@@ -15,3 +16,24 @@ authors:
 - given-names: "Elena"
   family-names: "Rivas"
 url: "https://github.com/marcellszi/rna3db"
+doi: "10.1016/j.jmb.2024.168552"
+date-released: 2024-04-26
+preferred-citation:
+  type: article
+  authors:
+  - given-names: "Marcell"
+    family-names: "Szikszai"
+  - given-names: "Marcin"
+    family-names: "Magnus"
+  - given-names: "Siddhant"
+    family-names: "Sanghi"
+  - given-names: "Sachin"
+    family-names: "Kadyan"
+  - given-names: "Nazim"
+    family-names: "Bouatta"
+  - given-names: "Elena"
+    family-names: "Rivas"
+  doi: "10.1016/j.jmb.2024.168552"
+  journal: "Journal of Molecular Biology"
+  title: "RNA3DB: A structurally-dissimilar dataset split for training and benchmarking deep learning models for RNA structure prediction"
+  year: 2024
diff --git a/rna3db/__main__.py b/rna3db/__main__.py
index fb4eeaf..0dcefb4 100755
--- a/rna3db/__main__.py
+++ b/rna3db/__main__.py
@@ -132,7 +132,16 @@ def main(args):
             args.input, args.output, args.tbl_dir, args.structural_e_value_cutoff
         )
     elif args.command == "split":
-        split(args.input, args.output, args.train_percentage, args.force_zero_test)
+        split(
+            args.input,
+            args.output,
+            splits=[
+                args.train_ratio,
+                args.valid_ratio,
+                1 - args.train_ratio - args.valid_ratio,
+            ],
+            force_zero_last=args.force_zero_test,
+        )
     else:
         raise ValueError
 
@@ -246,10 +255,16 @@ def main(args):
     split_parser.add_argument("input", type=Path, help="Input JSON file")
     split_parser.add_argument("output", type=Path, help="Output JSON file")
     split_parser.add_argument(
-        "--train_percentage",
+        "--train_ratio",
         type=float,
-        default=0.3,
-        help="Percentage of data for the train set",
+        default=0.7,
+        help="Ratio of data to use for the training set",
+    )
+    split_parser.add_argument(
+        "--valid_ratio",
+        type=float,
+        default=0.0,
+        help="Ratio of data to use for the validation set",
     )
     split_parser.add_argument(
         "--force_zero_test",
diff --git a/rna3db/parser.py b/rna3db/parser.py
index 1652610..bf38334 100644
--- a/rna3db/parser.py
+++ b/rna3db/parser.py
@@ -135,6 +135,10 @@ def __getitem__(self, idx):
     def __len__(self):
         return len(self.residues)
 
+    @property
+    def has_atoms(self):
+        return any(not res.is_missing for res in self)
+
     def add_residue(self, res: Residue):
         """Add a residue to the chain.
@@ -341,6 +345,186 @@ def __repr__(self):
             f"resolution={self.resolution}, release_date={self.release_date}, structure_method={self.structure_method})"
         )
 
+    @staticmethod
+    def _gen_mmcif_loop_str(name: str, headers: Sequence[str], values: Sequence[tuple]):
+        s = "#\nloop_\n"
+        for header in headers:
+            s += f"_{name}.{header}\n"
+
+        # pad each column to the width of its longest value
+        max_widths = {k: 0 for k in headers}
+        for V in values:
+            for k, v in zip(headers, V):
+                max_widths[k] = max(max_widths[k], len(str(v)))
+
+        for V in values:
+            row = ""
+            for k, v in zip(headers, V):
+                row += f"{str(v):<{max_widths[k]}} "
+            s += row + "\n"
+
+        return s
+
+    def write_mmcif_chain(self, output_path, author_id):
+        if not self[author_id].has_atoms:
+            raise ValueError(
+                f"Did not find any atoms for chain {author_id}. Did you set `include_atoms=True`?"
+            )
+        # extract needed info
+        entity_poly_seq_data = []
+        atom_site_data = []
+        atom_id = 0  # atom_site.id must be unique across the whole file
+        for i, res in enumerate(self[author_id]):
+            entity_poly_seq_data.append((1, res.index + 1, res.code, "n"))
+            for atom_name, atom_coords in res.atoms.items():
+                atom_id += 1
+                x, y, z = atom_coords
+                atom_site_data.append(
+                    (
+                        "ATOM",
+                        atom_id,
+                        atom_name[0],
+                        atom_name,
+                        ".",
+                        res.code,
+                        author_id,
+                        "?",
+                        i + 1,
+                        "?",
+                        x,
+                        y,
+                        z,
+                        1.0,
+                        0.0,
+                        "?",
+                        i + 1,
+                        res.code,
+                        author_id,
+                        atom_name,
+                        1,
+                    )
+                )
+
+        # build required strings
+        header_str = (
+            f"# generated by rna3db\n"
+            f"#\n"
+            f"data_{self.pdb_id}_{author_id}\n"
+            f"_entry.id {self.pdb_id}_{author_id}\n"
+            f"_pdbx_database_status.recvd_initial_deposition_date {self.release_date}\n"
+            f"_exptl.method '{self.structure_method.upper()}'\n"
+            f"_reflns.d_resolution_high {self.resolution}\n"
+            f"_entity_poly.pdbx_seq_one_letter_code_can {self[author_id].sequence}\n"
+        )
+        struct_asym_str = StructureFile._gen_mmcif_loop_str(
+            "struct_asym",
+            [
+                "id",
+                "pdbx_blank_PDB_chainid_flag",
+                "pdbx_modified",
+                "entity_id",
+                "details",
+            ],
+            [(author_id, "N", "N", 1, "?")],
+        )
+        chem_comp_str = StructureFile._gen_mmcif_loop_str(
+            "chem_comp",
+            [
+                "id",
+                "type",
+                "mon_nstd_flag",
+                "name",
+                "pdbx_synonyms",
+                "formula",
+                "formula_weight",
+            ],
+            [
+                (
+                    "A",
+                    "'RNA linking'",
+                    "y",
+                    '"ADENOSINE-5\'-MONOPHOSPHATE"',
+                    "?",
+                    "'C10 H14 N5 O7 P'",
+                    347.221,
+                ),
+                (
+                    "C",
+                    "'RNA linking'",
+                    "y",
+                    '"CYTIDINE-5\'-MONOPHOSPHATE"',
+                    "?",
+                    "'C9 H14 N3 O8 P'",
+                    323.197,
+                ),
+                (
+                    "G",
+                    "'RNA linking'",
+                    "y",
+                    '"GUANOSINE-5\'-MONOPHOSPHATE"',
+                    "?",
+                    "'C10 H14 N5 O8 P'",
+                    363.221,
+                ),
+                (
+                    "U",
+                    "'RNA linking'",
+                    "y",
+                    '"URIDINE-5\'-MONOPHOSPHATE"',
+                    "?",
+                    "'C9 H13 N2 O9 P'",
+                    324.181,
+                ),
+                ("T", "'RNA linking'", "y", '"T"', "?", "''", 0),
+                ("N", "'RNA linking'", "y", '"N"', "?", "''", 0),
+            ],
+        )
+        entity_poly_seq_str = StructureFile._gen_mmcif_loop_str(
+            "entity_poly_seq",
+            [
+                "entity_id",
+                "num",
+                "mon_id",
+                "hetero",
+            ],
+            entity_poly_seq_data,
+        )
+        atom_site_str = StructureFile._gen_mmcif_loop_str(
+            "atom_site",
+            [
+                "group_PDB",
+                "id",
+                "type_symbol",
+                "label_atom_id",
+                "label_alt_id",
+                "label_comp_id",
+                "label_asym_id",
+                "label_entity_id",
+                "label_seq_id",
+                "pdbx_PDB_ins_code",
+                "Cartn_x",
+                "Cartn_y",
+                "Cartn_z",
+                "occupancy",
+                "B_iso_or_equiv",
+                "pdbx_formal_charge",
+                "auth_seq_id",
+                "auth_comp_id",
+                "auth_asym_id",
+                "auth_atom_id",
+                "pdbx_PDB_model_num",
+            ],
+            atom_site_data,
+        )
+
+        # write to file
+        with open(output_path, "w") as f:
+            f.write(header_str)
+            f.write(struct_asym_str)
+            f.write(chem_comp_str)
+            f.write(entity_poly_seq_str)
+            f.write(atom_site_str)
+
 
 class mmCIFParser:
     def __init__(
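For context, `_gen_mmcif_loop_str` just renders an mmCIF `loop_` block with space-padded columns. A minimal sketch of its behavior on a toy table (the two-column input here is illustrative; `write_mmcif_chain` passes the full `struct_asym`, `chem_comp`, `entity_poly_seq`, and `atom_site` tables):

```python
from rna3db.parser import StructureFile

# Illustrative call only; real tables are built inside write_mmcif_chain.
print(
    StructureFile._gen_mmcif_loop_str(
        "struct_asym",
        ["id", "details"],
        [("A", "?")],
    )
)
# Prints (note the trailing space after each padded column):
# #
# loop_
# _struct_asym.id
# _struct_asym.details
# A ?
```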
diff --git a/rna3db/split.py b/rna3db/split.py
index 8f16f49..6e27ae9 100644
--- a/rna3db/split.py
+++ b/rna3db/split.py
@@ -1,11 +1,37 @@
+import math
+import random
+
+from typing import Sequence
+
 from rna3db.utils import PathLike, read_json, write_json
 
 
+def find_optimal_components(lengths_dict, capacity):
+    """Find the subset of components whose total length best fills `capacity`.
+
+    This is a 0/1 knapsack, solved exactly with dynamic programming.
+    """
+    component_names = list(lengths_dict.keys())
+    lengths = list(lengths_dict.values())
+
+    dp = [0] * (capacity + 1)
+    trace = [[] for _ in range(capacity + 1)]
+    for i in range(len(lengths)):
+        for j in range(capacity, lengths[i] - 1, -1):
+            if dp[j] < dp[j - lengths[i]] + lengths[i]:
+                dp[j] = dp[j - lengths[i]] + lengths[i]
+                trace[j] = trace[j - lengths[i]] + [component_names[i]]
+
+    return set(trace[capacity])
+
+
 def split(
     input_path: PathLike,
     output_path: PathLike,
-    train_size: float = 0.7,
-    force_zero_test: bool = True,
+    splits: Sequence[float] = (0.7, 0.0, 0.3),
+    split_names: Sequence[str] = ("train_set", "valid_set", "test_set"),
+    shuffle: bool = False,
+    force_zero_last: bool = False,
 ):
-    """A function that splits a JSON of components into a train/test set.
+    """A function that splits a JSON of components into train/valid/test sets.
 
@@ -16,35 +42,42 @@ def split(
     Args:
         input_path (PathLike): path to JSON containing components
         output_path (PathLike): path to output JSON
-        train_size (float): percentage of data to use as training set
-        force_zero_test (bool): whether to force component_0 into the test set
+        splits (Sequence[float]): fraction of data in each split; must sum to 1.0
+        split_names (Sequence[str]): names of the output splits
+        shuffle (bool): whether to shuffle components before splitting
+        force_zero_last (bool): whether to force component_0 into the last split
     """
+    # use isclose to avoid spurious failures from floating-point rounding
+    if not math.isclose(sum(splits), 1.0):
+        raise ValueError("Sum of splits must equal 1.0.")
+
     # read json
     cluster_json = read_json(input_path)
 
     # count number of repr sequences
-    total_repr_clusters = sum(len(v) for v in cluster_json.values())
-
-    # figure out which components need to go into training set
-    train_components = set()
-    train_set_length = 0
-    i = 1 if force_zero_test else 0
-    while train_set_length / total_repr_clusters < train_size:
-        # skip if it's not a real component (should only happen with 0)
-        if f"component_{i}" not in cluster_json:
-            i += 1
-            continue
-        train_components.add(f"component_{i}")
-        train_set_length += len(cluster_json[f"component_{i}"].keys())
-        i += 1
-
-    # test_components are just total-train_components
-    test_components = set(cluster_json.keys()) - train_components
-
-    # actually build JSON
-    output = {"train_set": {}, "test_set": {}}
-    for k in sorted(train_components):
-        output["train_set"][k] = cluster_json[k]
-    for k in sorted(test_components):
-        output["test_set"][k] = cluster_json[k]
+    lengths = {k: len(v) for k, v in cluster_json.items()}
+    total_repr_clusters = sum(lengths.values())
+
+    # shuffle if we want to add randomness
+    if shuffle:
+        items = list(lengths.items())
+        random.shuffle(items)
+        lengths = dict(items)
+
+    output = {k: {} for k in split_names}
+
+    if force_zero_last:
+        output[split_names[-1]]["component_0"] = cluster_json["component_0"]
+        lengths.pop("component_0")
+
+    capacities = [int(total_repr_clusters * ratio) for ratio in splits]
+    for name, capacity in zip(split_names, capacities):
+        components = find_optimal_components(lengths, capacity)
+        for k in sorted(components):
+            lengths.pop(k)
+            output[name][k] = cluster_json[k]
+
+    # every component must end up in exactly one split
+    assert len(lengths) == 0
 
     write_json(output, output_path)
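`find_optimal_components` solves a 0/1 knapsack over component sizes, so each split is filled as close to its capacity as possible rather than greedily. A quick illustration with toy sizes (the component names are made up):

```python
from rna3db.split import find_optimal_components

# Capacity 5 is best filled by 3 + 2, not by the single largest component (4).
lengths = {"component_1": 4, "component_2": 3, "component_3": 2}
print(find_optimal_components(lengths, capacity=5))
# {'component_2', 'component_3'}
```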
diff --git a/scripts/build_incremental_release_fasta.py b/scripts/build_incremental_release_fasta.py
new file mode 100644
index 0000000..344670b
--- /dev/null
+++ b/scripts/build_incremental_release_fasta.py
@@ -0,0 +1,25 @@
+from rna3db.utils import read_json
+from rna3db.parser import write_fasta
+
+from pathlib import Path
+
+import argparse
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Extract only new sequences from two parse outputs, and write them to a FASTA."
+    )
+    parser.add_argument("old_path", type=Path)
+    parser.add_argument("new_path", type=Path)
+    parser.add_argument("output_path", type=Path)
+    args = parser.parse_args()
+
+    old_parse = read_json(args.old_path)
+    new_parse = read_json(args.new_path)
+
+    descriptions, sequences = [], []
+    for k in set(new_parse.keys()) - set(old_parse.keys()):
+        descriptions.append(k)
+        sequences.append(new_parse[k]["sequence"])
+
+    write_fasta(descriptions, sequences, args.output_path)
diff --git a/scripts/get_nohits.py b/scripts/get_nohits.py
index 11b4cfa..9df583e 100644
--- a/scripts/get_nohits.py
+++ b/scripts/get_nohits.py
@@ -1,5 +1,5 @@
 from rna3db.tabular import read_tbls_from_dir
-from rna3db.utils import read_json
+from rna3db.parser import parse_fasta, write_fasta
 
 from pathlib import Path
 import argparse
@@ -13,11 +13,16 @@
     parser.add_argument("tbls_path", type=Path)
     args = parser.parse_args()
 
-    parse_json = read_json(args.input_path)
-    all_chains = set(parse_json.keys())
+    all_chains, all_sequences = parse_fasta(args.input_path)
+    all_dict = dict(zip(all_chains, all_sequences))
+    all_chains = set(all_chains)
+
     tbl = read_tbls_from_dir(args.tbls_path)
     hit_chains = set(tbl.query_name)
 
-    with open(args.output_path, "w") as f:
-        for nohit in all_chains - hit_chains:
-            f.write(f'>{nohit}\n{parse_json[nohit]["sequence"]}\n')
+    nohits = all_chains - hit_chains
+
+    output_descriptions = list(nohits)
+    output_sequences = [all_dict[i] for i in nohits]
+
+    write_fasta(output_descriptions, output_sequences, args.output_path)
diff --git a/scripts/json_to_mmcif.py b/scripts/json_to_mmcif.py
index 65d0a8b..5f6f09e 100644
--- a/scripts/json_to_mmcif.py
+++ b/scripts/json_to_mmcif.py
@@ -1,234 +1,54 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 from rna3db.utils import read_json
+from rna3db.parser import parse_file
 
 from pathlib import Path
+from tqdm import tqdm
 
 import argparse
-import os
 
 
-def format_line(
-    atom_index, atom_name, residue_name, chain_id, residue_index, x, y, z, verbose=False
-):
-    group_PDB = "ATOM"
-    atom_id = atom_index
-    type_symbol = atom_name[0]
-    label_atom_id = atom_name
-    label_alt_id = "."
-    label_comp_id = residue_name
-    label_asym_id = chain_id
-    label_entity_id = "?"
-    label_seq_id = residue_index
-    pdbx_PDB_ins_code = "?"
-    Cartn_x = x
-    Cartn_y = y
-    Cartn_z = z
-    occupancy = 1.0
-    B_iso_or_equiv = 0.0
-    auth_seq_id = residue_index
-    auth_asym_id = chain_id
-    pdbx_PDB_model_num = "1"
-    mmcif_format = (
-        f"{group_PDB:<6}"
-        f"{atom_id:>7} "
-        f"{type_symbol:>2} "
-        f"{label_atom_id:<4} "
-        f"{label_alt_id:<1} "
-        f"{label_comp_id:>3} "
-        f"{label_asym_id:>2} "
-        f"{label_entity_id:>1} "
-        f"{label_seq_id:>3} "
-        f"{pdbx_PDB_ins_code:>2} "
-        f"{Cartn_x:>8.3f} "
-        f"{Cartn_y:>8.3f} "
-        f"{Cartn_z:>8.3f} "
-        f"{occupancy} "
-        f"{B_iso_or_equiv:>6.2f} "
-        f"? "
" - f"{auth_seq_id:>4} " - f"{label_comp_id:>3} " - f"{label_asym_id:>2} " - f"{label_atom_id:>4}" - f"{pdbx_PDB_model_num:>2}" +def main(args): + data = read_json(args.input_json_path) + num_chains = sum( + len(i.keys()) for s in data.values() for c in s.values() for i in c.values() ) - # Now we replace the variables in the mmCIF format - formatted_mmcif_line = mmcif_format.format( - group_PDB=group_PDB, - atom_id=atom_id, - type_symbol=type_symbol, - label_atom_id=label_atom_id, - label_alt_id=label_alt_id, - label_comp_id=label_comp_id, - label_asym_id=label_asym_id, - label_entity_id=label_entity_id, - label_seq_id=label_seq_id, - pdbx_PDB_ins_code=pdbx_PDB_ins_code, - Cartn_x=Cartn_x, - Cartn_y=Cartn_y, - Cartn_z=Cartn_z, - occupancy=occupancy, - B_iso_or_equiv=B_iso_or_equiv, - auth_seq_id=auth_seq_id, - auth_asym_id=auth_asym_id, - pdbx_PDB_model_num=pdbx_PDB_model_num, - ) - - if verbose: - print(formatted_mmcif_line) - return formatted_mmcif_line + "\n" - - -def run(data, output_path, verbose = False): - log = "" - for dataset, value in data.items(): - dataset_pdb_index = 1 - try: - os.makedirs(output_path + os.sep + dataset) - except FileExistsError: - pass - for component, value in data[dataset].items(): - for cluster, value in data[dataset][component].items(): - for pdb, value in data[dataset][component][cluster].items(): - release_date = value["release_date"] - resolution = value["resolution"] - - structure_method = "" - - if "electron microscopy" in value["structure_method"]: - structure_method = "'ELECTRON MICROSCOPY'" - if "x-ray diffraction" == value["structure_method"]: - structure_method = "'X-RAY DIFFRACTION'" - if not structure_method: - print("error: no method") - - pdb_id, chain_id = pdb.split("_") - - cif_txt = f"""# generated by rna3db -# -data_{pdb_id}_{chain_id} -_entry.id {pdb_id}_{chain_id} -_pdbx_database_status.recvd_initial_deposition_date {release_date} -_exptl.method {structure_method} -_reflns.d_resolution_high {resolution} -_entity_poly.pdbx_seq_one_letter_code_can {value['sequence']} -# -loop_ -_struct_asym.id -_struct_asym.pdbx_blank_PDB_chainid_flag -_struct_asym.pdbx_modified -_struct_asym.entity_id -_struct_asym.details -{chain_id} N N 1 ? -## -loop_ -_chem_comp.id -_chem_comp.type -_chem_comp.mon_nstd_flag -_chem_comp.name -_chem_comp.pdbx_synonyms -_chem_comp.formula -_chem_comp.formula_weight -A 'RNA linking' y "ADENOSINE-5'-MONOPHOSPHATE" ? 'C10 H14 N5 O7 P' 347.221 -C 'RNA linking' y "CYTIDINE-5'-MONOPHOSPHATE" ? 'C9 H14 N3 O8 P' 323.197 -G 'RNA linking' y "GUANOSINE-5'-MONOPHOSPHATE" ? 'C10 H14 N5 O8 P' 363.221 -U 'RNA linking' y "URIDINE-5'-MONOPHOSPHATE" ? 'C9 H13 N2 O9 P' 324.181 -T 'RNA linking' y "T" ? '' 0 -N 'RNA linking' y "N" ? 
'' 0 -""" - - cif_txt_poly_seq = f"""# -# -loop_ -_entity_poly_seq.entity_id -_entity_poly_seq.num -_entity_poly_seq.mon_id -_entity_poly_seq.hetero -""" - cif_txt_atom_site = f""" -# -loop_ -_atom_site.group_PDB -_atom_site.id -_atom_site.type_symbol -_atom_site.label_atom_id -_atom_site.label_alt_id -_atom_site.label_comp_id -_atom_site.label_asym_id -_atom_site.label_entity_id -_atom_site.label_seq_id -_atom_site.pdbx_PDB_ins_code -_atom_site.Cartn_x -_atom_site.Cartn_y -_atom_site.Cartn_z -_atom_site.occupancy -_atom_site.B_iso_or_equiv -_atom_site.pdbx_formal_charge -_atom_site.auth_seq_id -_atom_site.auth_comp_id -_atom_site.auth_asym_id -_atom_site.auth_atom_id -_atom_site.pdbx_PDB_model_num -""" - - if verbose: print(f"# {dataset_pdb_index} >{pdb_id}_{chain_id}") - log += f'# {dataset_pdb_index} >{pdb_id}_{chain_id} \n{value["sequence"]} {len(value["sequence"])}\n' - dataset_pdb_index += 1 - - seq = value["sequence"] - - atom_index = 1 - residue_index = 0 - - for residue in value["atoms"]: - residue_index += 1 # at the beginning of the cycle, so this will not mess up continues/break - for atom_name, xyz in residue.items(): - # xyz = ' '.join([str(f) for f in residue[atom_name]]) - x, y, z = xyz - - cif_txt_atom_site += format_line( - atom_index, - atom_name, - seq[ - residue_index - 1 - ], # residue_name, # which name to use it? seq[residue_index - 1] - chain_id, - residue_index, - x, - y, - z, - verbose=False, - ) - atom_index += 1 - - directory = f"{output_path}/{dataset}/{component}/{cluster}" - if not os.path.exists(directory): - os.makedirs(directory) - fn = ( - output_path - + f"/{dataset}/{component}/{cluster}/{pdb_id}_{chain_id}.cif" - ) - with open(fn, "w") as f: - if verbose: print(f"save {fn}") - f.write(cif_txt) - f.write(cif_txt_poly_seq.strip()) - f.write(cif_txt_atom_site) - - with open(f"{output_path}/rna3db.log", "w") as f: - f.write(log) + pbar = tqdm(total=num_chains) + + for set_name, set_content in data.items(): + for component_name, component_content in set_content.items(): + for cluster_name, cluster_content in component_content.items(): + cluster_path = ( + args.output_dir / set_name / component_name / cluster_name + ) + cluster_path.mkdir(exist_ok=True, parents=True) + for chain_name in cluster_content.keys(): + pdb_id, author_id = chain_name.split("_") + output_path = cluster_path / f"{chain_name}.cif" + pdb_mmcif_path = args.input_mmcif_dir / f"{pdb_id}.cif" + if not pdb_mmcif_path.is_file(): + print(f"WARNING: could not find {pdb_mmcif_path}") + continue + if not output_path.is_file() or not args.skip_existing: + sf = parse_file(pdb_mmcif_path, include_atoms=True) + sf.write_mmcif_chain(output_path, author_id) + pbar.update(1) if __name__ == "__main__": parser = argparse.ArgumentParser( description="Converts a JSON produced by RNA3DB's parse command to a set of PDBx/mmCIF files." 
diff --git a/scripts/slurm/build_full_release.slurm b/scripts/slurm/build_full_release.slurm
new file mode 100644
index 0000000..e7a1fc1
--- /dev/null
+++ b/scripts/slurm/build_full_release.slurm
@@ -0,0 +1,88 @@
+#!/bin/bash
+#SBATCH -c 64
+#SBATCH -t 0
+#SBATCH -p 
+#SBATCH --mem=64000
+#SBATCH -o logs/rna3db_full_release_%j.out
+#SBATCH -e logs/rna3db_full_release_%j.err
+#SBATCH --mail-user=
+#SBATCH --mail-type=ALL
+
+# where you want the release to be output to
+OUTPUT_DIR=""
+
+# you set these once and forget
+RNA3DB_ROOT_DIR=""
+MMCIF_DIR=""
+CMSCAN=""
+CMDB=""
+
+NEW_RELEASE_DATE=$(date +"%Y-%m-%d")
+mkdir -p $OUTPUT_DIR/$NEW_RELEASE_DATE
+mkdir -p $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans
+
+# prepare the env
+mamba activate rna3db
+
+# download latest mmCIF files
+bash $RNA3DB_ROOT_DIR/scripts/download_pdb_mmcif.sh $MMCIF_DIR
+
+# run parse
+python -m rna3db parse \
+    $MMCIF_DIR \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/parse.json
+
+# run filter
+python -m rna3db filter \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/parse.json \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/filter.json
+
+# write all sequences to a FASTA file
+python $RNA3DB_ROOT_DIR/scripts/json_to_fasta.py \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/parse.json \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/$NEW_RELEASE_DATE.fasta
+
+# do cmscan on all sequences
+$CMSCAN --cpu 64 \
+    -o $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans/$NEW_RELEASE_DATE.o \
+    -tbl $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans/$NEW_RELEASE_DATE.tbl \
+    $CMDB \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/$NEW_RELEASE_DATE.fasta
+
+# find all sequences that did not get a hit
+python $RNA3DB_ROOT_DIR/scripts/get_nohits.py \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/$NEW_RELEASE_DATE.fasta \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/$NEW_RELEASE_DATE-nohits.fasta \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans/
+
+# re-scan the sequences that did not get a hit, with --max
+$CMSCAN --max --cpu 64 \
+    -o $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans/$NEW_RELEASE_DATE-nohits.o \
+    -tbl $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans/$NEW_RELEASE_DATE-nohits.tbl \
+    $CMDB \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/$NEW_RELEASE_DATE-nohits.fasta
+
+# run cluster
+python -m rna3db cluster \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/filter.json \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/cluster.json \
+    --tbl_dir $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans
+
+# run split
+python -m rna3db split \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/cluster.json \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/split.json
+
+# make mmCIFs
+python $RNA3DB_ROOT_DIR/scripts/json_to_mmcif.py \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/split.json \
+    $MMCIF_DIR \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-mmcifs
+
+# compress files ready for release
+tar -czvf \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans.tar.gz \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans
+tar -cJvf \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-mmcifs.tar.xz \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-mmcifs
diff --git a/scripts/slurm/build_incremental_release.slurm b/scripts/slurm/build_incremental_release.slurm
new file mode 100644
index 0000000..fd66c0c
--- /dev/null
+++ b/scripts/slurm/build_incremental_release.slurm
@@ -0,0 +1,90 @@
+#!/bin/bash
+#SBATCH -c 64
+#SBATCH -t 0
+#SBATCH -p 
+#SBATCH --mem=64000
+#SBATCH -o logs/rna3db_incremental_release_%j.out
+#SBATCH -e logs/rna3db_incremental_release_%j.err
+#SBATCH --mail-user=
+#SBATCH --mail-type=ALL
+
+# the output dir, along with the date of the last release
+OUTPUT_DIR=""
+PREVIOUS_RELEASE_DATE="2024-04-26"
+
+# you set these once and forget
+RNA3DB_ROOT_DIR=""
+MMCIF_DIR=""
+CMSCAN=""
+CMDB=""
+
+NEW_RELEASE_DATE=$(date +"%Y-%m-%d")
+mkdir -p $OUTPUT_DIR/$NEW_RELEASE_DATE
+mkdir -p $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans
+
+# prepare the env
+mamba activate rna3db
+
+# download latest mmCIF files
+bash $RNA3DB_ROOT_DIR/scripts/download_pdb_mmcif.sh $MMCIF_DIR
+
+# run parse
+python -m rna3db parse \
+    $MMCIF_DIR \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/parse.json
+
+# run filter
+python -m rna3db filter \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/parse.json \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/filter.json
+
+# write only the new sequences to a FASTA file
+python $RNA3DB_ROOT_DIR/scripts/build_incremental_release_fasta.py \
+    $OUTPUT_DIR/$PREVIOUS_RELEASE_DATE/rna3db-jsons/parse.json \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/parse.json \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/$NEW_RELEASE_DATE.fasta
+
+# do cmscan on all new sequences
+$CMSCAN --cpu 64 \
+    -o $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans/$NEW_RELEASE_DATE.o \
+    -tbl $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans/$NEW_RELEASE_DATE.tbl \
+    $CMDB \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/$NEW_RELEASE_DATE.fasta
+
+# find new sequences that did not get a hit
+python $RNA3DB_ROOT_DIR/scripts/get_nohits.py \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/$NEW_RELEASE_DATE.fasta \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/$NEW_RELEASE_DATE-nohits.fasta \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans/
+
+# re-scan the new sequences that did not get a hit, with --max
+$CMSCAN --max --cpu 64 \
+    -o $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans/$NEW_RELEASE_DATE-nohits.o \
+    -tbl $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans/$NEW_RELEASE_DATE-nohits.tbl \
+    $CMDB \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/$NEW_RELEASE_DATE-nohits.fasta
+
+# run cluster
+python -m rna3db cluster \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/filter.json \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/cluster.json \
+    --tbl_dir $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans
+
+# run split
+python -m rna3db split \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/cluster.json \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/split.json
+
+# make mmCIFs
+python $RNA3DB_ROOT_DIR/scripts/json_to_mmcif.py \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-jsons/split.json \
+    $MMCIF_DIR \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-mmcifs
+
+# compress files ready for release
+tar -czvf \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans.tar.gz \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-cmscans
+tar -cJvf \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-mmcifs.tar.xz \
+    $OUTPUT_DIR/$NEW_RELEASE_DATE/rna3db-mmcifs
install_requires=["biopython", "tqdm"], + install_requires=["biopython", "tqdm", "pre-commit"], ) diff --git a/tests/test_split.py b/tests/test_split.py new file mode 100644 index 0000000..a8a85e1 --- /dev/null +++ b/tests/test_split.py @@ -0,0 +1,71 @@ +import unittest + +from rna3db.split import find_optimal_components + +import unittest + + +class TestFindOptimalComponents(unittest.TestCase): + def test_small_lengths(self): + lengths_dict = {"component1": 1, "component2": 2, "component3": 3} + capacity = 5 + expected_components = set(["component2", "component3"]) + expected_length = sum( + lengths_dict[component] for component in expected_components + ) + result_components = find_optimal_components(lengths_dict, capacity) + result_length = sum(lengths_dict[component] for component in result_components) + self.assertEqual(result_length, expected_length) + self.assertCountEqual(result_components, expected_components) + + def test_exact_fit(self): + lengths_dict = {"component1": 4, "component2": 3, "component3": 2} + capacity = 7 + expected_components = set(["component1", "component2"]) + expected_length = sum( + lengths_dict[component] for component in expected_components + ) + result_components = find_optimal_components(lengths_dict, capacity) + result_length = sum(lengths_dict[component] for component in result_components) + self.assertEqual(result_length, expected_length) + self.assertCountEqual(result_components, expected_components) + + def test_over_capacity(self): + lengths_dict = {"component1": 8, "component2": 9, "component3": 10} + capacity = 5 + expected_components = set([]) + expected_length = sum( + lengths_dict[component] for component in expected_components + ) + result_components = find_optimal_components(lengths_dict, capacity) + result_length = sum(lengths_dict[component] for component in result_components) + self.assertEqual(result_length, expected_length) + self.assertEqual(result_components, expected_components) + + def test_zero_capacity(self): + lengths_dict = {"component1": 1, "component2": 2, "component3": 3} + capacity = 0 + expected_components = set([]) + expected_length = sum( + lengths_dict[component] for component in expected_components + ) + result_components = find_optimal_components(lengths_dict, capacity) + result_length = sum(lengths_dict[component] for component in result_components) + self.assertEqual(result_length, expected_length) + self.assertEqual(result_components, expected_components) + + def test_large_numbers(self): + lengths_dict = {"component1": 100, "component2": 300, "component3": 400} + capacity = 800 + expected_components = set(["component1", "component2", "component3"]) + expected_length = sum( + lengths_dict[component] for component in expected_components + ) + result_components = find_optimal_components(lengths_dict, capacity) + result_length = sum(lengths_dict[component] for component in result_components) + self.assertEqual(result_length, expected_length) + self.assertCountEqual(result_components, expected_components) + + +if __name__ == "__main__": + unittest.main()