Improved optimal split solver and other fixes.
- The solver that places components into training/testing sets has been rewritten to find the optimal solution using integer linear programming (a usage sketch follows the file summary below)
- Improved handling of `nmr_resolution` when a file lists multiple `structure_method`s
- Removed deprecated SCOPData equivalence tests
- Added new tests and general improvements to testing
- Improved the compatibility of the mmCIF parser/writer for release dates
marcellszi committed Dec 4, 2024
1 parent 36e390d commit 8744e85
Showing 5 changed files with 297 additions and 141 deletions.
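
For context, a hypothetical end-to-end call of the rewritten splitter; the file paths are placeholders, while the signature and defaults are taken from the rna3db/split.py diff below.

# hypothetical invocation; "cluster.json" and "split.json" are placeholder paths
from rna3db.split import split

output = split(
    "cluster.json",                      # clustered components produced upstream
    "split.json",                        # optional output path (may be None)
    splits=[0.7, 0.0, 0.3],              # ratios must sum to 1.0
    split_names=["train_set", "valid_set", "test_set"],
)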
7 changes: 5 additions & 2 deletions rna3db/parser.py
@@ -431,6 +431,7 @@ def write_mmcif_chain(self, output_path, author_id):
f"data_{self.pdb_id}_{author_id}\n"
f"_entry.id {self.pdb_id}_{author_id}\n"
f"_pdbx_database_status.recvd_initial_deposition_date {self.release_date}\n"
f"_pdbx_audit_revision_history.revision_date {self.release_date}\n" # this + above for better compatibility
f"_exptl.method '{self.structure_method.upper()}'\n"
f"_reflns.d_resolution_high {self.resolution}\n"
f"_entity_poly.pdbx_seq_one_letter_code_can {self[author_id].sequence}\n"
@@ -589,7 +590,9 @@ def release_date(self):
if "_pdbx_audit_revision_history.revision_date" in self.parsed_info:
return min(self.parsed_info["_pdbx_audit_revision_history.revision_date"])
# use deposition date if there are no revisions
return self.parsed_info["_pdbx_database_status.recvd_initial_deposition_date"]
return min(
self.parsed_info["_pdbx_database_status.recvd_initial_deposition_date"]
)

@property
def resolution(self):
@@ -604,7 +607,7 @@ def resolution(self):
                 resolutions.append(float(self.parsed_info[res_key][0]))

         # if we have an NMR structure and we overwrite default NMR resolution
-        if self.structure_method == "solution nmr" and self.nmr_resolution is not None:
+        if "solution nmr" in self.structure_method and self.nmr_resolution is not None:
             return self.nmr_resolution

         if len(resolutions) == 0:
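
Two details of this parser change are easy to miss: ISO "YYYY-MM-DD" date strings sort lexicographically in chronological order, so `min()` over the raw strings picks the earliest date, and the substring test covers files whose `structure_method` combines several methods. A minimal sketch with made-up values:

# made-up values for illustration; not from a real mmCIF file
revision_dates = ["2021-06-30", "2019-01-15", "2020-03-02"]
print(min(revision_dates))  # "2019-01-15" -- lexicographic min is the earliest date

structure_method = "solution nmr; theoretical model"
print("solution nmr" in structure_method)  # True -- substring check matches combined methods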
122 changes: 92 additions & 30 deletions rna3db/split.py
@@ -1,28 +1,70 @@
 import random
+import pulp

 from typing import Sequence

 from rna3db.utils import PathLike, read_json, write_json


-def find_optimal_components(lengths_dict, capacity):
-    component_name = list(lengths_dict.keys())
-    lengths = list(lengths_dict.values())
-
-    dp = [0] * (capacity + 1)
-    trace = [[] for i in range(capacity + 1)]
-    for i in range(len(lengths)):
-        for j in range(capacity, lengths[i] - 1, -1):
-            if dp[j] < dp[j - lengths[i]] + lengths[i]:
-                dp[j] = dp[j - lengths[i]] + lengths[i]
-                trace[j] = trace[j - lengths[i]] + [component_name[i]]
-
-    return set(trace[capacity])
+def find_optimal_components(
+    components: Sequence[int], bins: Sequence[int], verbose: bool = False
+) -> Sequence[set[int]]:
+    """Function used to find optimal placement of components into
+    training/testing sets.
+
+    We use an ILP formulation that is very similar to the classic ILP
+    formulation of the bin packing problem.
+
+    Args:
+        components (Sequence[int]): list of component sizes
+        bins (Sequence[int]): list of bin sizes
+        verbose (bool): whether to print verbose output
+
+    Returns:
+        Sequence[set[int]]: list of sets, where each set contains the indices
+            of the components that go into that bin
+    """
+
+    n, k = len(components), len(bins)
+
+    # set up problem
+    p = pulp.LpProblem("OptimalComponentSolver", pulp.LpMinimize)
+    x = pulp.LpVariable.dicts(
+        "x", ((i, j) for i in range(n) for j in range(k)), cat="Binary"
+    )
+    deviation = pulp.LpVariable.dicts(
+        "d", (j for j in range(k)), lowBound=0, cat="Continuous"
+    )
+
+    # we want to minimise total "deviation"
+    # (deviation is the total sum of the difference between target bins and found bins)
+    p += pulp.lpSum(deviation[j] for j in range(k))
+
+    # components can go into exactly one bin
+    for i in range(n):
+        p += pulp.lpSum(x[(i, j)] for j in range(k)) == 1, f"AssignComponent_{i}"
+
+    # deviation constraints (to handle abs)
+    for j in range(k):
+        total_weight_in_bin = pulp.lpSum(components[i] * x[(i, j)] for i in range(n))
+        p += total_weight_in_bin - bins[j] <= deviation[j], f"DeviationPos_{j}"
+        p += bins[j] - total_weight_in_bin <= deviation[j], f"DeviationNeg_{j}"
+
+    # solve ILP problem with PuLP
+    p.solve(pulp.PULP_CBC_CMD(msg=int(verbose)))
+
+    # extract solution in sensible format
+    sol = [set() for i in range(k)]
+    for i in range(k):
+        for j in range(n):
+            if pulp.value(x[(j, i)]) == 1:
+                sol[i].add(j)
+
+    return sol


 def split(
     input_path: PathLike,
-    output_path: PathLike,
+    output_path: PathLike = None,
     splits: Sequence[float] = [0.7, 0.0, 0.3],
     split_names: Sequence[str] = ["train_set", "valid_set", "test_set"],
     shuffle: bool = False,
@@ -41,33 +83,53 @@ def split(
     if sum(splits) != 1.0:
         raise ValueError("Sum of splits must equal 1.0.")

-    # read json
+    if len(splits) != len(split_names):
+        raise ValueError("Number of splits must match number of split names.")
+
     cluster_json = read_json(input_path)

-    # count number of repr sequences
-    lengths = {k: len(v) for k, v in cluster_json.items()}
-    total_repr_clusters = sum(lengths.values())
+    # get lengths of the components, and mapping from idx to keys
+    keys, lengths = [], []
+    for k, v in cluster_json.items():
+        if force_zero_last and k == "component_0":
+            continue
+        keys.append(k)
+        lengths.append(len(v))

-    # shuffle if we want to add randomness
-    if shuffle:
-        L = list(zip(component_name, lengths))
-        random.shuffle(L)
-        component_name, lengths = zip(*L)
-        component_name, lengths = list(component_name), list(lengths)
+    # calculate actual bin capacities
+    # rounding is probably close enough
+    bins = [round(sum(lengths) * ratio) for ratio in splits]

     # create output dict
     output = {k: {} for k in split_names}

     # force `component_0` into the last bin
     if force_zero_last:
+        if bins[-1] < len(cluster_json["component_0"]):
+            print(
+                "ERROR: cannot force `component_0` into the last bin. Increase the last bin size."
+            )
+            raise ValueError
+        bins[-1] -= len(cluster_json["component_0"])
         output[split_names[-1]]["component_0"] = cluster_json["component_0"]
-        lengths.pop("component_0")
+        del cluster_json["component_0"]
+
+    if shuffle:
+        L = list(zip(keys, lengths))
+        random.shuffle(L)
+        keys, lengths = zip(*L)
+        keys, lengths = list(keys), list(lengths)
+
+    # find optimal split with ILP
+    sol = find_optimal_components(lengths, bins)

-    capacities = [round(total_repr_clusters * ratio) for ratio in splits]
-    for name, capacity in zip(split_names, capacities):
-        components = find_optimal_components(lengths, capacity)
-        for k in sorted(components):
-            lengths.pop(k)
+    # write output to dict
+    for idx, name in enumerate(split_names):
+        for k in sorted(sol[idx]):
+            k = keys[k]
             output[name][k] = cluster_json[k]

-    assert len(lengths) == 0
+    if output_path:
+        write_json(output, output_path)

-    write_json(output, output_path)
+    return output
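
The pair of constraints per bin is the standard linearisation of an absolute value: together they force deviation[j] >= |weight_j - bins[j]|, and the minimisation drives it down to equality. A hypothetical call with made-up sizes (requires the new pulp dependency):

# hypothetical usage; component sizes and bin targets are made up
from rna3db.split import find_optimal_components

components = [5, 3, 2, 4]   # e.g. number of sequences per cluster
bins = [7, 7]               # a 50/50 split of the 14 sequences
sol = find_optimal_components(components, bins)
print(sol)  # e.g. [{0, 2}, {1, 3}]: 5+2 and 3+4 match both targets exactly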
2 changes: 1 addition & 1 deletion setup.py
@@ -6,5 +6,5 @@
description="A dataset for training and benchmarking deep learning models for RNA structure prediction",
author="Marcell Szikszai",
packages=find_packages(exclude=["tests", "scripts", "data"]),
install_requires=["biopython", "tqdm", "black", "pre-commit"],
install_requires=["biopython", "tqdm", "black", "pre-commit", "pulp"],
)
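
PuLP is the only new dependency; it ships with a bundled CBC build on common platforms, which is what `PULP_CBC_CMD` invokes, so no separate solver install should be needed. A quick sanity check, assuming a standard `pip install pulp`:

import pulp
print(pulp.PULP_CBC_CMD(msg=0).available())  # True when the bundled CBC binary is usable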
53 changes: 5 additions & 48 deletions tests/test_modifications.py
@@ -1,63 +1,20 @@
+from Bio.Data import PDBData
+from pathlib import Path
 import unittest
 import pickle

 from rna3db.parser import ModificationHandler
-from Bio.Data import SCOPData
-
-# import PDBData


 class TestModifications(unittest.TestCase):
     modification_handler = ModificationHandler(
-        "tests/test_data/modifications_cache.json"
+        Path(__file__).parent / "test_data" / "modifications_cache.json"
     )

-    # accept both A and U for selenocysteines
-    # note that selenocysteines that are incorrectly encoded in SCOPData
-    # shouldn't be included here
-    selenocysteines = ["SEC"]
-
-    def test_scop_equivalence(self):
-        for k in SCOPData.protein_letters_3to1.keys():
-            k = k.rstrip()  # to address biopython's weird whitespace
-
-            scop_code = SCOPData.protein_letters_3to1.get(k, "X")
-            scop_code = scop_code if len(scop_code) == 1 else "X"
-            protein_code = self.modification_handler.protein_letters_3to1(k)
-
-            # if it's not a protein we don't care about equivalence
-            if not self.modification_handler.is_protein(k):
-                continue
-
-            # selenocysteines (U) are treated as A by AlphaFold
-            # so we don't care if they are encoded as A in SCOPData
-            if k in self.selenocysteines:
-                self.assertTrue(protein_code == "A" or protein_code == "U")
-                continue
-
-            # these two modifications were fixed in a later version of
-            # biopython so we check them manually
-            if k == "4BF":
-                self.assertEqual(protein_code, "F")
-                continue
-            if k == "PCA":
-                self.assertEqual(protein_code, "Q")
-                continue
-
-            # we are happy to recover more unknowns than SCOPData
-            self.assertTrue(
-                protein_code == scop_code or "X" == scop_code,
-            )

     def test_biopython_coverage(self):
-        for k, v in SCOPData.protein_letters_3to1.items():
+        for k, v in PDBData.nucleic_letters_3to1_extended.items():
             k = k.rstrip()  # to address biopython's weird whitespace

-            if self.modification_handler.is_protein(k):
-                self.assertTrue(
-                    v == "X" or self.modification_handler.protein_letters_3to1(k) != "X"
-                )
-
+            # we just check that we have all PDBData modifications at least
             if self.modification_handler.is_rna(k):
                 self.assertTrue(
                     v == "N" or self.modification_handler.rna_letters_3to1(k) != "N"
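
A minimal standalone sketch of the retained coverage check, using the same ModificationHandler API exercised by the test above (run from the repository root, using the same cache file as the test):

from Bio.Data import PDBData
from rna3db.parser import ModificationHandler

handler = ModificationHandler("tests/test_data/modifications_cache.json")
for code in PDBData.nucleic_letters_3to1_extended:
    code = code.rstrip()  # biopython pads some keys with whitespace
    if handler.is_rna(code):
        print(code, handler.rna_letters_3to1(code))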
