def gemmi_replace_seq_3to1(seq: PolymerChainSequence) -> PolymerChainSequence:
    """
    Re-derive the one-letter codes at modified positions using gemmi.

    For every ``(position, residue_name)`` pair in ``seq.modifications`` the
    letter at that 1-based position of ``seq.sequence`` is replaced with the
    first character returned by ``gemmi.one_letter_code([residue_name])``.
    A warning is logged whenever the letter actually changes.

    Args:
        seq: PolymerChainSequence object. It is not mutated.

    Returns:
        A new PolymerChainSequence carrying the updated sequence string; the
        entity type, modifications and original entity/chain ids are passed
        through from ``seq`` unchanged.

    Raises:
        IndexError: If a modification position lies outside
            ``1..len(seq.sequence)``.
    """
    residues = list(seq.sequence)
    for mod_pos, mod_res_name in seq.modifications:
        # Positions are 1-based. Without this check a position of 0 (or a
        # negative one) would wrap around through Python's negative indexing
        # and silently rewrite a residue counted from the end of the chain.
        if not 1 <= mod_pos <= len(residues):
            raise IndexError(
                f"Modification position {mod_pos} ({mod_res_name}) is outside "
                f"the sequence range 1..{len(residues)}."
            )
        idx = mod_pos - 1
        old = residues[idx]
        new = gemmi.one_letter_code([mod_res_name])[0]
        if old != new:
            logging.warning(
                f"Change {mod_res_name}->{old} to {mod_res_name}->{new} for consistent with boltz (use gemmi.one_letter_code)."
            )
            residues[idx] = new
    return PolymerChainSequence(
        entity_type=seq.entity_type,
        sequence="".join(residues),
        modifications=seq.modifications,
        ori_entity_id=seq.ori_entity_id,
        ori_chain_id=seq.ori_chain_id,
    )