Skip to content

Commit 855e899

Browse files
committed
Improved splitting, and preparing for new release.
- Improved the way splitting works - Added SLURM scripts to simplify generating releases in the future - Modified a bunch of scripts to simplify making new releases - Moved the mmCIF generation code to rna3db/parser.py - Added pre-commits - Modified CITATION.cff to reflect JMB article
1 parent de8df71 commit 855e899

13 files changed

+621
-256
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ data/
1313
[._]ss[a-gi-z]
1414
[._]sw[a-p]
1515

16-
**/.DS_Store
16+
**/.DS_Store

.pre-commit-config.yaml

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
repos:
2+
- repo: local
3+
hooks:
4+
- id: unittests
5+
name: run unit tests
6+
entry: python -m unittest
7+
language: system
8+
pass_filenames: false
9+
args: ["discover"]
10+
- repo: https://github.com/pre-commit/pre-commit-hooks
11+
rev: v2.3.0
12+
hooks:
13+
- id: check-yaml
14+
- id: end-of-file-fixer
15+
- id: trailing-whitespace
16+
- repo: https://github.com/psf/black
17+
rev: 24.3.0
18+
hooks:
19+
- id: black

CITATION.cff

+22
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
cff-version: 1.2.0
22
message: "If you use this software, please cite it as below."
33
title: "RNA3DB: A dataset for training and benchmarking deep learning models for RNA structure prediction"
4+
version: 1.1
45
authors:
56
- given-names: "Marcell"
67
family-names: "Szikszai"
@@ -15,3 +16,24 @@ authors:
1516
- given-names: "Elena"
1617
family-names: "Rivas"
1718
url: "https://github.com/marcellszi/rna3db"
19+
doi: "10.1016/j.jmb.2024.168552"
20+
date-released: 2024-04-26
21+
preferred-citation:
22+
type: article
23+
authors:
24+
- given-names: "Marcell"
25+
family-names: "Szikszai"
26+
- given-names: "Marcin"
27+
family-names: Magnus
28+
- given-names: "Siddhant"
29+
family-names: "Sanghi"
30+
- given-names: "Sachin"
31+
family-names: "Kadyan"
32+
- given-names: "Nazim"
33+
family-names: "Bouatta"
34+
- given-names: "Elena"
35+
family-names: "Rivas"
36+
doi: "10.1016/j.jmb.2024.168552"
37+
journal: "Journal of Molecular Biology"
38+
title: "RNA3DB: A structurally-dissimilar dataset split for training and benchmarking deep learning models for RNA structure prediction"
39+
year: 2024

rna3db/__main__.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,16 @@ def main(args):
132132
args.input, args.output, args.tbl_dir, args.structural_e_value_cutoff
133133
)
134134
elif args.command == "split":
135-
split(args.input, args.output, args.train_percentage, args.force_zero_test)
135+
split(
136+
args.input,
137+
args.output,
138+
splits=[
139+
args.train_ratio,
140+
args.valid_ratio,
141+
1 - args.train_ratio - args.valid_ratio,
142+
],
143+
force_zero_last=args.force_zero_test,
144+
)
136145
else:
137146
raise ValueError
138147

@@ -246,10 +255,16 @@ def main(args):
246255
split_parser.add_argument("input", type=Path, help="Input JSON file")
247256
split_parser.add_argument("output", type=Path, help="Output JSON file")
248257
split_parser.add_argument(
249-
"--train_percentage",
258+
"--train_ratio",
250259
type=float,
251-
default=0.3,
252-
help="Percentage of data for the train set",
260+
default=0.7,
261+
help="Ratio of data to use for the training set",
262+
)
263+
split_parser.add_argument(
264+
"--valid_ratio",
265+
type=float,
266+
default=0.0,
267+
help="Ratio of the data to use for the validation set",
253268
)
254269
split_parser.add_argument(
255270
"--force_zero_test",

rna3db/parser.py

+180
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ def __getitem__(self, idx):
135135
def __len__(self):
136136
return len(self.residues)
137137

138+
@property
def has_atoms(self):
    """Whether at least one residue in this chain is not flagged missing."""
    for residue in self:
        if not residue.is_missing:
            return True
    return False
141+
138142
def add_residue(self, res: Residue):
139143
"""Add a residue to the chain.
140144
@@ -341,6 +345,182 @@ def __repr__(self):
341345
f"resolution={self.resolution}, release_date={self.release_date}, structure_method={self.structure_method})"
342346
)
343347

348+
@staticmethod
349+
def _gen_mmcif_loop_str(name: str, headers: Sequence[str], values: Sequence[tuple]):
350+
s = "#\nloop_\n"
351+
for header in headers:
352+
s += f"_{name}.{header}\n"
353+
354+
max_widths = {k: 0 for k in headers}
355+
for V in values:
356+
for k, v in zip(headers, V):
357+
max_widths[k] = max(max_widths[k], len(str(v)))
358+
359+
for V in values:
360+
row = ""
361+
for k, v in zip(headers, V):
362+
row += f"{str(v):<{max_widths[k]}} "
363+
s += row + "\n"
364+
365+
return s
366+
367+
def write_mmcif_chain(self, output_path, author_id):
    """Write a single chain of this structure to a minimal mmCIF file.

    Args:
        output_path: destination path for the generated mmCIF file.
        author_id: author chain ID of the chain to write.

    Raises:
        ValueError: if the chain has no resolved atoms (e.g. the structure
            was parsed without ``include_atoms=True``).
    """
    if not self[author_id].has_atoms:
        raise ValueError(
            f"Did not find any atoms for chain {author_id}. Did you set `include_atoms=True`?"
        )
    # extract needed info
    entity_poly_seq_data = []
    atom_site_data = []
    for i, res in enumerate(self[author_id]):
        entity_poly_seq_data.append((1, res.index + 1, res.code, "n"))
        for idx, (atom_name, atom_coords) in enumerate(res.atoms.items()):
            x, y, z = atom_coords
            atom_site_data.append(
                (
                    "ATOM",
                    # NOTE(review): atom_site.id restarts at 1 for every
                    # residue; mmCIF expects a file-unique id — confirm
                    # downstream readers tolerate this before changing.
                    idx + 1,
                    atom_name[0],
                    atom_name,
                    ".",
                    res.code,
                    author_id,
                    "?",
                    i + 1,
                    "?",
                    x,
                    y,
                    z,
                    1.0,
                    0.0,
                    "?",
                    i + 1,
                    res.code,
                    author_id,
                    atom_name,
                    1,
                )
            )

    # build required strings
    header_str = (
        f"# generated by rna3db\n"
        f"#\n"
        f"data_{self.pdb_id}_{author_id}\n"
        f"_entry.id {self.pdb_id}_{author_id}\n"
        f"_pdbx_database_status.recvd_initial_deposition_date {self.release_date}\n"
        f"_exptl.method '{self.structure_method.upper()}'\n"
        f"_reflns.d_resolution_high {self.resolution}\n"
        f"_entity_poly.pdbx_seq_one_letter_code_can {self[author_id].sequence}\n"
    )
    # NOTE: _gen_mmcif_loop_str prepends the leading underscore itself, so
    # category names are passed WITHOUT it (passing "_struct_asym" here used
    # to emit an invalid "__struct_asym.id" tag).
    struct_asym_str = StructureFile._gen_mmcif_loop_str(
        "struct_asym",
        [
            "id",
            "pdbx_blank_PDB_chainid_flag",
            "pdbx_modified",
            "entity_id",
            "details",
        ],
        [("A", "N", "N", 1, "?")],
    )
    chem_comp_str = StructureFile._gen_mmcif_loop_str(
        "chem_comp",
        [
            "id",
            "type",
            "mon_nstd_flag",
            "pdbx_synonyms",
            "formula",
            "formula_weight",
        ],
        # NOTE(review): formulas/weights below are reproduced as-is — verify
        # the G/U formula entries against the chemical component dictionary.
        [
            (
                "A",
                "'RNA linking'",
                "y",
                '"ADENOSINE-5\'-MONOPHOSPHATE"',
                "?",
                "'C10 H14 N5 O7 P'",
                347.221,
            ),
            (
                "C",
                "'RNA linking'",
                "y",
                '"CYTIDINE-5\'-MONOPHOSPHATE"',
                "?",
                "'C9 H14 N3 O8 P'",
                323.197,
            ),
            (
                "G",
                "'RNA linking'",
                "y",
                '"GUANOSINE-5\'-MONOPHOSPHATE"',
                "?",
                "'C9 H13 N2 O9 P'",
                363.221,
            ),
            (
                "U",
                "'RNA linking'",
                "y",
                '"URIDINE-5\'-MONOPHOSPHATE"',
                "?",
                "'C9 H13 N2 O9 P'",
                324.181,
            ),
            ("T", "'RNA linking'", "y", '"T"', "?", "''", 0),
            ("N", "'RNA linking'", "y", '"N"', "?", "''", 0),
        ],
    )
    entity_poly_seq_str = StructureFile._gen_mmcif_loop_str(
        "entity_poly_seq",
        [
            "entity_id",
            "num",
            "mon_id",
            "heter",
        ],
        entity_poly_seq_data,
    )
    atom_site_str = StructureFile._gen_mmcif_loop_str(
        "atom_site",
        [
            "group_PDB",
            "id",
            "type_symbol",
            "label_atom_id",
            "label_alt_id",
            "label_comp_id",
            "label_asym_id",
            "label_entity_id",
            "label_seq_id",
            "pdbx_PDB_ins_code",
            "Cartn_x",
            "Cartn_y",
            "Cartn_z",
            "occupancy",
            "B_iso_or_equiv",
            "pdbx_formal_charge",
            "auth_seq_id",
            "auth_comp_id",
            "auth_asym_id",
            "auth_atom_id",
            "pdbx_PDB_model_num",
        ],
        atom_site_data,
    )

    # write to file
    with open(output_path, "w") as f:
        f.write(header_str)
        f.write(struct_asym_str)
        f.write(chem_comp_str)
        f.write(entity_poly_seq_str)
        f.write(atom_site_str)
523+
344524

345525
class mmCIFParser:
346526
def __init__(

rna3db/split.py

+55-28
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,32 @@
import math
import random

from typing import Sequence

from rna3db.utils import PathLike, read_json, write_json
26

37

8+
def find_optimal_components(lengths_dict, capacity):
    """Pick components whose total length best fills ``capacity``.

    Classic 0/1 knapsack: maximises the summed lengths of the chosen
    components without exceeding ``capacity``, and returns the set of
    component names achieving that optimum.
    """
    # best[c] = largest achievable total length using budget c;
    # chosen[c] = component names realising best[c]
    best = [0] * (capacity + 1)
    chosen = [[] for _ in range(capacity + 1)]

    for name, size in lengths_dict.items():
        # iterate budgets downwards so each component is used at most once
        for budget in range(capacity, size - 1, -1):
            candidate = best[budget - size] + size
            if candidate > best[budget]:
                best[budget] = candidate
                chosen[budget] = chosen[budget - size] + [name]

    return set(chosen[capacity])
21+
22+
423
def split(
    input_path: PathLike,
    output_path: PathLike,
    splits: Sequence[float] = (0.7, 0.0, 0.3),
    split_names: Sequence[str] = ("train_set", "valid_set", "test_set"),
    shuffle: bool = False,
    force_zero_last: bool = False,
):
    """Split a JSON of components into train/valid/test sets.

    Args:
        input_path (PathLike): path to JSON containing components
        output_path (PathLike): path to output JSON
        splits (Sequence[float]): fraction of the data to put in each set;
            must sum to 1.0
        split_names (Sequence[str]): key used for each set in the output JSON
        shuffle (bool): whether to shuffle component order before packing
        force_zero_last (bool): whether to force component_0 into the last set
    """
    # isclose, not ==: the default (0.7, 0.0, 0.3) sums to 0.999... in floats
    if not math.isclose(sum(splits), 1.0):
        raise ValueError("Sum of splits must equal 1.0.")

    # read json
    cluster_json = read_json(input_path)

    # count number of repr sequences per component
    lengths = {k: len(v) for k, v in cluster_json.items()}
    total_repr_clusters = sum(lengths.values())

    # shuffle component order if we want to add randomness (affects which
    # equally-good knapsack solution is found)
    if shuffle:
        items = list(lengths.items())
        random.shuffle(items)
        lengths = dict(items)

    output = {k: {} for k in split_names}

    if force_zero_last:
        output[split_names[-1]]["component_0"] = cluster_json["component_0"]
        lengths.pop("component_0")

    # fill all but the last split with a best-fit subset of components
    capacities = [int(total_repr_clusters * ratio) for ratio in splits]
    for name, capacity in zip(split_names[:-1], capacities[:-1]):
        components = find_optimal_components(lengths, capacity)
        for k in sorted(components):
            lengths.pop(k)
            output[name][k] = cluster_json[k]

    # everything left over goes into the last split; this guarantees every
    # component is assigned even when int() truncation shrinks capacities
    for k in sorted(lengths):
        output[split_names[-1]][k] = cluster_json[k]

    write_json(output, output_path)

0 commit comments

Comments
 (0)