Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

improve chain letter naming #291

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
from chai_lab.data.features.generators.token_pair_pocket_restraint import (
TokenPairPocketRestraint,
)
from chai_lab.data.io.cif_utils import save_to_cif
from chai_lab.data.io.cif_utils import get_chain_letter, save_to_cif
from chai_lab.data.parsing.restraints import parse_pairwise_table
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.model.diffusion_schedules import InferenceNoiseSchedule
Expand Down Expand Up @@ -872,7 +872,7 @@ def avg_per_token_1d(x):
inputs["atom_token_index"],
)

ranking_outputs = rank(
ranking_outputs: SampleRanking = rank(
atom_pos[idx : idx + 1],
atom_mask=inputs["atom_exists_mask"],
atom_token_index=inputs["atom_token_index"],
Expand Down Expand Up @@ -908,7 +908,8 @@ def avg_per_token_1d(x):
write_path=cif_out_path,
# Set asym names to be A, B, C, ...
asym_entity_names={
i + 1: chr(i + 65) for i in range(len(feature_context.chains))
i: get_chain_letter(i)
for i in range(1, len(feature_context.chains) + 1)
},
)
cif_paths.append(cif_out_path)
Expand Down
4 changes: 3 additions & 1 deletion chai_lab/data/io/cif_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
_CHAIN_VOCAB = _CHAIN_VOCAB + [x + y for x in _CHAIN_VOCAB for y in _CHAIN_VOCAB]


def _get_chain_letter(asym_id: int) -> str:
def get_chain_letter(asym_id: int) -> str:
    """Return the chain name for a one-indexed asym_id.

    Names are looked up in ``_CHAIN_VOCAB``. An asym_id outside the range
    ``1..len(_CHAIN_VOCAB)`` trips the assertion below.
    """
    assert 0 < asym_id <= len(_CHAIN_VOCAB)
    # Shift to the zero-based vocabulary index: 1 -> A, 2 -> B
    return _CHAIN_VOCAB[asym_id - 1]

Expand Down
24 changes: 24 additions & 0 deletions tests/test_cif_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2024 Chai Discovery, Inc.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for details.


import pytest

from chai_lab.data.io.cif_utils import get_chain_letter


def test_get_chain_letter():
    """Check the asym_id -> chain name mapping at every vocabulary boundary."""
    # For one-letter codes, there are 26 + 26 = 52 codes
    # For two-letter codes, there are 52 * 52 codes
    last_asym_id = 52 * 52 + 52

    # asym_ids are one-indexed, so zero is rejected; one past the final
    # two-letter code is outside the vocabulary and must be rejected as well.
    for out_of_range in (0, last_asym_id + 1):
        with pytest.raises(AssertionError):
            get_chain_letter(out_of_range)

    expected_names = {
        # one-letter codes: upper case first, then lower case
        1: "A",
        26: "Z",
        27: "a",
        52: "z",
        # two-letter codes start right after the one-letter codes
        53: "AA",
        54: "AB",
        last_asym_id: "zz",
    }
    for asym_id, name in expected_names.items():
        assert get_chain_letter(asym_id) == name
Loading