From 768b011e10b9fdc02f3443f122753d28ee0a9b44 Mon Sep 17 00:00:00 2001 From: minhuanli Date: Tue, 30 Apr 2024 15:39:00 -0400 Subject: [PATCH] include some methods to go between orth and frac --- SFC_Torch/Fmodel.py | 28 +++++++++++++++++++++++++++- SFC_Torch/io.py | 41 +++++++++++++++++++++++++++++++++++++---- SFC_Torch/symmetry.py | 6 ++++-- 3 files changed, 68 insertions(+), 7 deletions(-) diff --git a/SFC_Torch/Fmodel.py b/SFC_Torch/Fmodel.py index f1adfbb..52ff042 100644 --- a/SFC_Torch/Fmodel.py +++ b/SFC_Torch/Fmodel.py @@ -238,7 +238,7 @@ def atom_pos_frac(self): """ Tensor of atom's Positions in fractional space, [Nc,3] """ - return torch.tensordot(self.atom_pos_orth, self.orth2frac_tensor.T, 1) + return self.orth2frac(self.atom_pos_orth) @property def cra_name(self): @@ -262,6 +262,32 @@ def n_atoms(self): def unique_atom(self): return list(set(self.atom_name)) + def frac2orth(self, frac_pos: torch.Tensor) -> torch.Tensor: + """ + Convert fractional coordinates to orthogonal coordinates + + Args: + frac_pos: torch.Tensor, [n_points, ..., 3] + + Returns: + orthogonal coordinates, torch.Tensor + """ + orth_pos = torch.einsum("n...x,yx->n...y", frac_pos, self.frac2orth_tensor) + return orth_pos + + def orth2frac(self, orth_pos: torch.Tensor) -> torch.Tensor: + """ + Convert orthogonal coordinates to fractional coordinates + + Args: + orth_pos: torch.Tensor, [n_points, ..., 3] + + Returns: + fractional coordinates, torch.Tensor + """ + frac_pos = torch.einsum("n...x,yx->n...y", orth_pos, self.orth2frac_tensor) + return frac_pos + def init_mtz(self, mtzdata, N_bins, expcolumns, set_experiment, freeflag, testset_value): """ set mtz file for HKL list, resolution and experimental related properties diff --git a/SFC_Torch/io.py b/SFC_Torch/io.py index 979dfa6..6a88c45 100644 --- a/SFC_Torch/io.py +++ b/SFC_Torch/io.py @@ -1,7 +1,7 @@ import gemmi -import torch -import numpy as np import urllib.request, os + +import numpy as np from tqdm import tqdm import pandas as pd @@ -151,8 +151,6 @@ def sequence(self): sequence = "".join([gemmi.find_tabulated_residue(r).one_letter_code for r in sequence]) return sequence - - def to_gemmi(self, include_header=True): """ Convert the array data to gemmi.Structure @@ -178,6 +176,10 @@ def to_gemmi(self, include_header=True): # Next time user parse the new pdb with gemmi, will have all the info again st.raw_remarks = self.pdb_header return st + + @property + def atom_pos_frac(self): + return self.orth2frac(self.atom_pos) def set_spacegroup(self, spacegroup): """ @@ -290,11 +292,42 @@ def from_atom_slices(self, atom_slices, inplace=False): return new_parser def move2cell(self): + """ + move the current model into the cell by shifting + """ frac_mat = np.array(self.cell.fractionalization_matrix.tolist()) mean_positions_frac = np.dot(frac_mat, np.mean(assert_numpy(self.atom_pos), axis=0)) shift_vec = np.dot(np.linalg.inv(frac_mat), mean_positions_frac % 1.0 - mean_positions_frac) self.set_positions(assert_numpy(self.atom_pos) + shift_vec) + def orth2frac(self, orth_pos: np.ndarray) -> np.ndarray: + """ + Convert orthogonal coordinates to fractional coordinates + + Args: + orth_pos: np.ndarray, [n_points, ..., 3] + + Returns: + frational coordinates, np.ndarray, [n_points, ..., 3] + """ + orth2frac_mat = np.array(self.unit_cell.fractionalization_matrix.tolist()) + frac_pos = np.einsum("n...x,yx->n...y", orth_pos, orth2frac_mat) + return frac_pos + + def frac2orth(self, frac_pos: np.ndarray) -> np.ndarray: + """ + Convert fractional coordinates to orthogonal coordinates + + Args: + frac_pos: np.ndarray, [n_points, ..., 3] + + Returns: + orthogonal coordinates, np.ndarray, [n_points, ..., 3] + """ + frac2orth_mat = np.array(self.unit_cell.orthogonalization_matrix.tolist()) + orth_pos = np.einsum("n...x,yx->n...y", frac_pos, frac2orth_mat) + return orth_pos + def savePDB(self, savefilename, include_header=True): structure = self.to_gemmi(include_header=include_header) structure.write_pdb(savefilename) diff --git a/SFC_Torch/symmetry.py b/SFC_Torch/symmetry.py index 4846814..4121cb2 100644 --- a/SFC_Torch/symmetry.py +++ b/SFC_Torch/symmetry.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import numpy as np import gemmi import torch import reciprocalspaceship as rs import pandas as pd -from typing import Optional, List, Union +from typing import List ccp4_hkl_asu = [ 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, @@ -275,7 +277,7 @@ def asu2p1_torch(atom_pos_orth, unitcell, spacegroup, return sym_oped_pos_orth -def get_polar_axis(spacegroup : gemmi.SpaceGroup) -> Optional[List[int]]: +def get_polar_axis(spacegroup : gemmi.SpaceGroup) -> List[int] | None: """ Return list of polar axis of a spacegroup