Skip to content

Commit 5028e71

Browse files
authored
improve chain letter naming (#291)
1 parent 9e7777a commit 5028e71

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

chai_lab/chai1.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
from chai_lab.data.features.generators.token_pair_pocket_restraint import (
8787
TokenPairPocketRestraint,
8888
)
89-
from chai_lab.data.io.cif_utils import save_to_cif
89+
from chai_lab.data.io.cif_utils import get_chain_letter, save_to_cif
9090
from chai_lab.data.parsing.restraints import parse_pairwise_table
9191
from chai_lab.data.parsing.structure.entity_type import EntityType
9292
from chai_lab.model.diffusion_schedules import InferenceNoiseSchedule
@@ -872,7 +872,7 @@ def avg_per_token_1d(x):
872872
inputs["atom_token_index"],
873873
)
874874

875-
ranking_outputs = rank(
875+
ranking_outputs: SampleRanking = rank(
876876
atom_pos[idx : idx + 1],
877877
atom_mask=inputs["atom_exists_mask"],
878878
atom_token_index=inputs["atom_token_index"],
@@ -908,7 +908,8 @@ def avg_per_token_1d(x):
908908
write_path=cif_out_path,
909909
# Set asym names to be A, B, C, ...
910910
asym_entity_names={
911-
i + 1: chr(i + 65) for i in range(len(feature_context.chains))
911+
i: get_chain_letter(i)
912+
for i in range(1, len(feature_context.chains) + 1)
912913
},
913914
)
914915
cif_paths.append(cif_out_path)

chai_lab/data/io/cif_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
4040
_CHAIN_VOCAB = _CHAIN_VOCAB + [x + y for x in _CHAIN_VOCAB for y in _CHAIN_VOCAB]
4141

4242

43-
def _get_chain_letter(asym_id: int) -> str:
43+
def get_chain_letter(asym_id: int) -> str:
44+
"""Get chain given a one-indexed asym_id."""
45+
assert asym_id > 0 and asym_id <= len(_CHAIN_VOCAB)
4446
vocab_index = asym_id - 1 # 1 -> A, 2 -> B
4547
return _CHAIN_VOCAB[vocab_index]
4648

tests/test_cif_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) 2024 Chai Discovery, Inc.
2+
# Licensed under the Apache License, Version 2.0.
3+
# See the LICENSE file for details.
4+
5+
6+
import pytest
7+
8+
from chai_lab.data.io.cif_utils import get_chain_letter
9+
10+
11+
def test_get_chain_letter():
12+
with pytest.raises(AssertionError):
13+
get_chain_letter(0)
14+
assert get_chain_letter(1) == "A"
15+
assert get_chain_letter(26) == "Z"
16+
assert get_chain_letter(27) == "a"
17+
assert get_chain_letter(52) == "z"
18+
19+
assert get_chain_letter(53) == "AA"
20+
assert get_chain_letter(54) == "AB"
21+
22+
# For one-letter codes, there are 26 + 26 = 52 codes
23+
# For two-letter codes, there are 52 * 52 codes
24+
assert get_chain_letter(52 * 52 + 52) == "zz"

0 commit comments

Comments
 (0)