Skip to content

Commit dab29d0

Browse files
committed
Improvements to testing and parser
- Added `__eq__` functionality to `parser.Residue` and `parser.Chain` - Fixed issue with reading/writing files with new `_entity_poly.type` chain type determination - Added several tests for the parser, including reading/writing tests
1 parent 5567f82 commit dab29d0

File tree

5 files changed

+7787
-24
lines changed

5 files changed

+7787
-24
lines changed

rna3db/parser.py

Lines changed: 70 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,15 @@ def __init__(
9494
three_letter_code: str,
9595
one_letter_code: str,
9696
index: int,
97+
atoms: dict = None,
9798
):
9899
self.three_letter_code = three_letter_code
99100
self.one_letter_code = one_letter_code
100101
self.index = index
101-
self.atoms = {}
102+
# NOTE: we need to handle dict like this, cannot use `atoms: dict = {}` in method definition
103+
# See important warning: https://docs.python.org/3/tutorial/controlflow.html#default-argument-values
104+
# (the default value is evaluated only once, which causes issues with mutable dictionaries)
105+
self.atoms = atoms if atoms else {}
102106

103107
@property
104108
def code(self) -> str:
@@ -108,6 +112,15 @@ def code(self) -> str:
108112
def is_missing(self) -> bool:
109113
return not len(self.atoms) > 0
110114

115+
def __eq__(self, other) -> bool:
116+
# NOTE: we don't care about three letter codes, only one letter
117+
# this means modifications are still equal
118+
return (
119+
self.one_letter_code == other.one_letter_code
120+
and self.index == other.index
121+
and self.atoms == other.atoms
122+
)
123+
111124
def __repr__(self):
112125
return (
113126
f"Residue(code={self.code}, three_letter_code={self.three_letter_code}, "
@@ -135,6 +148,17 @@ def __getitem__(self, idx):
135148
def __len__(self):
136149
return len(self.residues)
137150

151+
def __eq__(self, other):
152+
# NOTE: we ignore the author_id for equality checks
153+
if len(self) != len(other):
154+
return False
155+
156+
for res_self, res_other in zip(self, other):
157+
if res_self != res_other:
158+
return False
159+
160+
return True
161+
138162
@property
139163
def has_atoms(self):
140164
return any([not res.is_missing for res in self])
@@ -314,9 +338,7 @@ def __init__(
314338
)
315339
file_parser = PDBParser
316340
else:
317-
raise ValueError(
318-
f"The extension {self.path.suffix.lower()} is not supported."
319-
)
341+
raise ValueError(f"The extension `{path.suffix.lower()}` is not supported.")
320342

321343
# make the parser
322344
parser = file_parser(
@@ -475,6 +497,15 @@ def write_mmcif_chain(self, output_path, author_id):
475497
("N", "'RNA linking'", "y", '"N"', "?", "''", 0),
476498
],
477499
)
500+
entity_poly = StructureFile._gen_mmcif_loop_str(
501+
"entity_poly",
502+
[
503+
"entity_id",
504+
"type",
505+
],
506+
[(1, "polyribonucleotide")],
507+
)
508+
478509
entity_poly_seq_str = StructureFile._gen_mmcif_loop_str(
479510
"entity_poly_seq",
480511
[
@@ -518,6 +549,7 @@ def write_mmcif_chain(self, output_path, author_id):
518549
f.write(header_str)
519550
f.write(struct_asym_str)
520551
f.write(chem_comp_str)
552+
f.write(entity_poly)
521553
f.write(entity_poly_seq_str)
522554
f.write(atom_site_str)
523555

@@ -648,14 +680,6 @@ def chains(self):
648680
k = mmcif_chain_to_entity_id[mmcif_chain_id]
649681
id_map[k].add(author_chain_id)
650682

651-
# get the chem_comp type for each mon_id
652-
chem_comp_type = {
653-
mon_id: comp_type
654-
for mon_id, comp_type in zip(
655-
self.parsed_info["_chem_comp.id"], self.parsed_info["_chem_comp.type"]
656-
)
657-
}
658-
659683
# parse full chains from "seqres"
660684
chains_full = defaultdict(Chain)
661685
for entity_id, mon_id, idx in zip(
@@ -673,14 +697,42 @@ def chains(self):
673697
)
674698
)
675699

676-
# Get chain/polymer types
677700
chain_type = {}
678-
for entity_id, poly_type in zip(
679-
self.parsed_info["_entity_poly.entity_id"],
680-
self.parsed_info["_entity_poly.type"],
701+
# we check if we have _entity_poly
702+
if (
703+
"_entity_poly.entity_id" in self.parsed_info
704+
and "_entity_poly.type" in self.parsed_info
681705
):
682-
for author_id in id_map[entity_id]:
683-
chain_type[author_id] = poly_type
706+
# get chain/polymer types
707+
for entity_id, poly_type in zip(
708+
self.parsed_info["_entity_poly.entity_id"],
709+
self.parsed_info["_entity_poly.type"],
710+
):
711+
for author_id in id_map[entity_id]:
712+
chain_type[author_id] = poly_type
713+
else:
714+
# if we don't have _entity_poly, we fall back to chem_comp type for each mon_id
715+
# this is for backwards compatibility with older RNA3BD version release mmCIFs
716+
chem_comp_type = {
717+
mon_id: comp_type
718+
for mon_id, comp_type in zip(
719+
self.parsed_info["_chem_comp.id"],
720+
self.parsed_info["_chem_comp.type"],
721+
)
722+
}
723+
for author_id, chain_data in chains_full.items():
724+
# "keep" only chains that contain at least one `self.molecule_type`
725+
if any(
726+
[
727+
self.molecule_type in chem_comp_type[i.three_letter_code]
728+
for i in chain_data.residues
729+
]
730+
):
731+
# if RNA we set to self.polymer_type (i.e. "polyribonucleotide")
732+
chain_type[author_id] = self.polymer_type
733+
else:
734+
# we just set to "other" if not an RNA
735+
chain_type[author_id] = "other"
684736

685737
# keep only chains of the appropriate polymer type
686738
chains = {}
@@ -722,12 +774,6 @@ def chains(self):
722774
)
723775

724776
# make sure that the sites actually match, should never be a mismatch
725-
"""
726-
assert (
727-
site.three_letter_code
728-
== chains[site.author_chain_id][seq_idx].three_letter_code
729-
), f"residue mismatch at chain {site.author_chain_id} pos {seq_idx} (expected {site.three_letter_code}, got {chains[site.author_chain_id][seq_idx].three_letter_code})"
730-
"""
731777
if (
732778
site.three_letter_code
733779
!= chains[site.author_chain_id][seq_idx].three_letter_code

0 commit comments

Comments
 (0)