Skip to content

Commit 8744e85

Browse files
committed
Improved optimal split solver and other fixes.
- Solver for the placement of components into training/testing sets has been rewritten to find the optimal solution using integer linear programming
- Improved handling of `nmr_resolution` when multiple `structure_method`s are present in a file
- Removed deprecated SCOPData equivalence tests
- Added new tests and general improvements to testing
- Improved the compatibility of the mmCIF parser/writer for release dates
1 parent 36e390d commit 8744e85

File tree

5 files changed

+297
-141
lines changed

5 files changed

+297
-141
lines changed

rna3db/parser.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ def write_mmcif_chain(self, output_path, author_id):
431431
f"data_{self.pdb_id}_{author_id}\n"
432432
f"_entry.id {self.pdb_id}_{author_id}\n"
433433
f"_pdbx_database_status.recvd_initial_deposition_date {self.release_date}\n"
434+
f"_pdbx_audit_revision_history.revision_date {self.release_date}\n" # this + above for better compatibility
434435
f"_exptl.method '{self.structure_method.upper()}'\n"
435436
f"_reflns.d_resolution_high {self.resolution}\n"
436437
f"_entity_poly.pdbx_seq_one_letter_code_can {self[author_id].sequence}\n"
@@ -589,7 +590,9 @@ def release_date(self):
589590
if "_pdbx_audit_revision_history.revision_date" in self.parsed_info:
590591
return min(self.parsed_info["_pdbx_audit_revision_history.revision_date"])
591592
# use deposition date if there are no revisions
592-
return self.parsed_info["_pdbx_database_status.recvd_initial_deposition_date"]
593+
return min(
594+
self.parsed_info["_pdbx_database_status.recvd_initial_deposition_date"]
595+
)
593596

594597
@property
595598
def resolution(self):
@@ -604,7 +607,7 @@ def resolution(self):
604607
resolutions.append(float(self.parsed_info[res_key][0]))
605608

606609
# if we have an NMR structure and we overwrite default NMR resolution
607-
if self.structure_method == "solution nmr" and self.nmr_resolution is not None:
610+
if "solution nmr" in self.structure_method and self.nmr_resolution is not None:
608611
return self.nmr_resolution
609612

610613
if len(resolutions) == 0:

rna3db/split.py

+92-30
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,70 @@
11
import random
2+
import pulp
23

34
from typing import Sequence
45

56
from rna3db.utils import PathLike, read_json, write_json
67

78

def find_optimal_components(
    components: Sequence[int], bins: Sequence[int], verbose: bool = False
) -> Sequence[set[int]]:
    """Find an optimal placement of components into training/testing sets.

    We use an ILP formulation that is very similar to the classic ILP
    formulation of the bin packing problem: every component is assigned to
    exactly one bin, and the objective minimises the total absolute
    deviation between each bin's target size and the summed sizes of the
    components placed into it.

    Args:
        components (Sequence[int]): list of component sizes
        bins (Sequence[int]): list of target bin sizes
        verbose (bool): whether to print solver output
    Returns:
        Sequence[set[int]]: list of sets, where each set contains the indices
            of the components that go into that bin
    Raises:
        RuntimeError: if the ILP solver does not reach an optimal solution
    """
    n, k = len(components), len(bins)

    # set up problem
    p = pulp.LpProblem("OptimalComponentSolver", pulp.LpMinimize)
    x = pulp.LpVariable.dicts(
        "x", ((i, j) for i in range(n) for j in range(k)), cat="Binary"
    )
    deviation = pulp.LpVariable.dicts(
        "d", (j for j in range(k)), lowBound=0, cat="Continuous"
    )

    # we want to minimise total "deviation"
    # (deviation is the total sum of the difference between target bins and found bins)
    p += pulp.lpSum(deviation[j] for j in range(k))

    # components can go into exactly one bin
    for i in range(n):
        p += pulp.lpSum(x[(i, j)] for j in range(k)) == 1, f"AssignComponent_{i}"

    # deviation constraints: linearise |weight(bin_j) - target_j| with a
    # pair of inequalities, since ILP objectives cannot contain abs()
    for j in range(k):
        total_weight_in_bin = pulp.lpSum(components[i] * x[(i, j)] for i in range(n))
        p += total_weight_in_bin - bins[j] <= deviation[j], f"DeviationPos_{j}"
        p += bins[j] - total_weight_in_bin <= deviation[j], f"DeviationNeg_{j}"

    # solve ILP problem with PuLP's bundled CBC solver
    p.solve(pulp.PULP_CBC_CMD(msg=int(verbose)))

    # fail loudly rather than silently returning a meaningless assignment
    # (the problem is always feasible, so this guards against solver errors)
    if pulp.LpStatus[p.status] != "Optimal":
        raise RuntimeError(
            f"ILP solver finished with non-optimal status: {pulp.LpStatus[p.status]}"
        )

    # extract solution in sensible format; compare against 0.5 instead of
    # exactly 1 because solvers report binary variables as floats that may
    # carry floating-point noise (e.g. 0.9999999)
    sol = [set() for _ in range(k)]
    for j in range(k):
        for i in range(n):
            if pulp.value(x[(i, j)]) > 0.5:
                sol[j].add(i)

    return sol
2163

2264

2365
def split(
2466
input_path: PathLike,
25-
output_path: PathLike,
67+
output_path: PathLike = None,
2668
splits: Sequence[float] = [0.7, 0.0, 0.3],
2769
split_names: Sequence[str] = ["train_set", "valid_set", "test_set"],
2870
shuffle: bool = False,
@@ -41,33 +83,53 @@ def split(
4183
if sum(splits) != 1.0:
4284
raise ValueError("Sum of splits must equal 1.0.")
4385

44-
# read json
86+
if len(splits) != len(split_names):
87+
raise ValueError("Number of splits must match number of split names.")
88+
4589
cluster_json = read_json(input_path)
4690

47-
# count number of repr sequences
48-
lengths = {k: len(v) for k, v in cluster_json.items()}
49-
total_repr_clusters = sum(lengths.values())
91+
# get lengths of the components, and mapping from idx to keys
92+
keys, lengths = [], []
93+
for k, v in cluster_json.items():
94+
if force_zero_last and k == "component_0":
95+
continue
96+
keys.append(k)
97+
lengths.append(len(v))
5098

51-
# shuffle if we want to add randomness
52-
if shuffle:
53-
L = list(zip(component_name, lengths))
54-
random.shuffle(L)
55-
component_name, lengths = zip(*L)
56-
component_name, lengths = list(component_name), list(lengths)
99+
# calculate actual bin capacities
100+
# rounding is probably close enough
101+
bins = [round(sum(lengths) * ratio) for ratio in splits]
57102

103+
# create output dict
58104
output = {k: {} for k in split_names}
59105

106+
# force `component_0` into the last bin
60107
if force_zero_last:
108+
if bins[-1] < len(cluster_json["component_0"]):
109+
print(
110+
"ERROR: cannot force `component_0` into the last bin. Increase the last bin size."
111+
)
112+
raise ValueError
113+
bins[-1] -= len(cluster_json["component_0"])
61114
output[split_names[-1]]["component_0"] = cluster_json["component_0"]
62-
lengths.pop("component_0")
115+
del cluster_json["component_0"]
116+
117+
if shuffle:
118+
L = list(zip(keys, lengths))
119+
random.shuffle(L)
120+
keys, lengths = zip(*L)
121+
keys, lengths = list(keys), list(lengths)
122+
123+
# find optimal split with ILP
124+
sol = find_optimal_components(lengths, bins)
63125

64-
capacities = [round(total_repr_clusters * ratio) for ratio in splits]
65-
for name, capacity in zip(split_names, capacities):
66-
components = find_optimal_components(lengths, capacity)
67-
for k in sorted(components):
68-
lengths.pop(k)
126+
# write output to dict
127+
for idx, name in enumerate(split_names):
128+
for k in sorted(sol[idx]):
129+
k = keys[k]
69130
output[name][k] = cluster_json[k]
70131

71-
assert len(lengths) == 0
132+
if output_path:
133+
write_json(output, output_path)
72134

73-
write_json(output, output_path)
135+
return output

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66
description="A dataset for training and benchmarking deep learning models for RNA structure prediction",
77
author="Marcell Szikszai",
88
packages=find_packages(exclude=["tests", "scripts", "data"]),
9-
install_requires=["biopython", "tqdm", "black", "pre-commit"],
9+
install_requires=["biopython", "tqdm", "black", "pre-commit", "pulp"],
1010
)

tests/test_modifications.py

+5-48
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,20 @@
1+
from Bio.Data import PDBData
2+
from pathlib import Path
13
import unittest
2-
import pickle
34

45
from rna3db.parser import ModificationHandler
5-
from Bio.Data import SCOPData
6-
7-
# import PDBData
86

97

108
class TestModifications(unittest.TestCase):
119
modification_handler = ModificationHandler(
12-
"tests/test_data/modifications_cache.json"
10+
Path(__file__).parent / "test_data" / "modifications_cache.json"
1311
)
1412

15-
# accept both A and U for selenocysteines
16-
# note that selenocysteines that are incorrectly encoded in SCOPData
17-
# shouldn't be included here
18-
selenocysteines = ["SEC"]
19-
20-
def test_scop_equivalence(self):
21-
for k in SCOPData.protein_letters_3to1.keys():
22-
k = k.rstrip() # to address biopython's weird whitespace
23-
24-
scop_code = SCOPData.protein_letters_3to1.get(k, "X")
25-
scop_code = scop_code if len(scop_code) == 1 else "X"
26-
protein_code = self.modification_handler.protein_letters_3to1(k)
27-
28-
# if it's not a protein we don't care about equivalence
29-
if not self.modification_handler.is_protein(k):
30-
continue
31-
32-
# selenocysteines (U) are treated as A by AlphaFold
33-
# so we don't care if they are encoded as A in SCOPData
34-
if k in self.selenocysteines:
35-
self.assertTrue(protein_code == "A" or protein_code == "U")
36-
continue
37-
38-
# these two modifications were fixed in a later version of
39-
# biopython so we check them manually
40-
if k == "4BF":
41-
self.assertEqual(protein_code, "F")
42-
continue
43-
if k == "PCA":
44-
self.assertEqual(protein_code, "Q")
45-
continue
46-
47-
# we are happy to recover more unknowns than SCOPData
48-
self.assertTrue(
49-
protein_code == scop_code or "X" == scop_code,
50-
)
51-
5213
def test_biopython_coverage(self):
53-
for k, v in SCOPData.protein_letters_3to1.items():
14+
for k, v in PDBData.nucleic_letters_3to1_extended.items():
5415
k = k.rstrip() # to address biopython's weird whitespace
5516

56-
if self.modification_handler.is_protein(k):
57-
self.assertTrue(
58-
v == "X" or self.modification_handler.protein_letters_3to1(k) != "X"
59-
)
60-
17+
# we just check that we have all PDBData modifications at least
6118
if self.modification_handler.is_rna(k):
6219
self.assertTrue(
6320
v == "N" or self.modification_handler.rna_letters_3to1(k) != "N"

0 commit comments

Comments (0)