Merge pull request #104 from neherlab/fix_average_rate
gtr: exclude gap state from average rate calculation
rneher authored Feb 9, 2020
2 parents c1d2e17 + 8b1adb9 commit 1f8bb25
Showing 1 changed file with 20 additions and 10 deletions.
treetime/gtr.py: 30 changes (20 additions & 10 deletions)
@@ -4,6 +4,12 @@
 from treetime import config as ttconf
 from .seq_utils import alphabets, profile_maps, alphabet_synonyms
 
+def avg_transition(W,pi, gap_index=None):
+    if gap_index is None:
+        return np.einsum('i,ij,j', pi, W, pi)
+    else:
+        return (np.einsum('i,ij,j', pi, W, pi) - np.sum(pi*W[:,gap_index])*pi[gap_index])/(1-pi[gap_index])
+
 
 class GTR(object):
     """
@@ -212,7 +218,7 @@ def assign_rates(self, mu=1.0, pi=None, W=None):
 
         self._W = 0.5*(W+W.T)
         np.fill_diagonal(W,0)
-        average_rate = W.dot(self.Pi).dot(self.Pi)
+        average_rate = avg_transition(W, self.Pi, gap_index=self.gap_index)
         self._W = W/average_rate
         self._mu *=average_rate
 
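
Not part of the commit: a sketch of the rescaling done here. W is divided by the (now gap-excluded) average rate and mu absorbs that factor, so the rescaled matrix has an average transition rate of exactly 1 and mu alone carries the overall rate. Toy numbers as in the example above.

    import numpy as np

    W = np.array([[0., 1., 2.], [1., 0., 3.], [2., 3., 0.]])
    pi = np.array([0.5, 0.3, 0.2])
    gap = 2

    avg = (np.einsum('i,ij,j', pi, W, pi) - np.sum(pi*W[:, gap])*pi[gap]) / (1 - pi[gap])
    W_scaled = W / avg        # corresponds to self._W = W/average_rate
    mu = 1.0 * avg            # corresponds to self._mu *= average_rate

    rescaled = (np.einsum('i,ij,j', pi, W_scaled, pi) - np.sum(pi*W_scaled[:, gap])*pi[gap]) / (1 - pi[gap])
    assert np.isclose(rescaled, 1.0)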
@@ -378,23 +384,27 @@ def standard(model, **kwargs):
         from .aa_models import JTT92
 
         if model.lower() in ['jc', 'jc69', 'jukes-cantor', 'jukes-cantor69', 'jukescantor', 'jukescantor69']:
-            return JC69(**kwargs)
+            model = JC69(**kwargs)
         elif model.lower() in ['k80', 'kimura80', 'kimura1980']:
-            return K80(**kwargs)
+            model = K80(**kwargs)
         elif model.lower() in ['f81', 'felsenstein81', 'felsenstein1981']:
-            return F81(**kwargs)
+            model = F81(**kwargs)
         elif model.lower() in ['hky', 'hky85', 'hky1985']:
-            return HKY85(**kwargs)
+            model = HKY85(**kwargs)
         elif model.lower() in ['t92', 'tamura92', 'tamura1992']:
-            return T92(**kwargs)
+            model = T92(**kwargs)
         elif model.lower() in ['tn93', 'tamura_nei_93', 'tamuranei93']:
-            return TN93(**kwargs)
+            model = TN93(**kwargs)
         elif model.lower() in ['jtt', 'jtt92']:
-            return JTT92(**kwargs)
+            model = JTT92(**kwargs)
         else:
             raise KeyError("The GTR model '{}' is not in the list of available models."
                            "".format(model))
+
+        model.mu = kwargs['mu'] if 'mu' in kwargs else 1.0
+        return model
+
 
     @classmethod
     def random(cls, mu=1.0, alphabet='nuc'):
         """
@@ -494,7 +504,7 @@ def infer(cls, nij, Ti, root_state, fixed_pi=None, pc=1.0, gap_limit=0.01, **kwargs):
                            + ttconf.TINY_NUMBER + 2*pc_mat)
 
             np.fill_diagonal(W_ij, 0)
-            scale_factor = np.einsum('i,ij,j',pi,W_ij,pi)
+            scale_factor = avg_transition(W_ij,pi, gap_index=gtr.gap_index)
 
             W_ij = W_ij/scale_factor
             if fixed_pi is None:
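
Not part of the commit: a comparison of the old and new scale factor on made-up numbers. The two expressions differ only when a gap state with non-zero frequency is present, which is the case the commit targets; for alphabets without a gap character (gap_index None) the helper falls back to the old expression.

    import numpy as np

    W_ij = np.array([[0., 1., 2.], [1., 0., 3.], [2., 3., 0.]])
    pi = np.array([0.25, 0.25, 0.5])   # deliberately gap-heavy: pi[2] = gap frequency
    gap_index = 2

    old = np.einsum('i,ij,j', pi, W_ij, pi)                                          # pre-commit
    new = (old - np.sum(pi*W_ij[:, gap_index])*pi[gap_index]) / (1 - pi[gap_index])  # post-commit
    print(old, new)   # 1.375 vs 1.5 for these numbers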
@@ -1002,7 +1012,7 @@ def sequence_logLH(self,seq, pattern_multiplicity=None):
                             for si,state in enumerate(self.alphabet)])
 
     def average_rate(self):
-        return -self.mu*np.einsum('ii,i',self.Q, self.Pi)
+        return self.mu*avg_transition(self.W, self.Pi, gap_index=self.gap_index)
 
     def save_to_npz(self, outfile):
         full_gtr = self.mu * np.dot(self.Pi, self.W)
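
Not part of the commit: a sketch of why the old and new expressions agree when no gap state is involved. It assumes the textbook GTR construction Q[i,j] = W[i,j]*pi[i] with the diagonal chosen so that each column of Q sums to zero; this illustrates the algebra only and is not a statement about treetime's internal Q property.

    import numpy as np

    W = np.array([[0., 1., 2.], [1., 0., 3.], [2., 3., 0.]])
    pi = np.array([0.5, 0.3, 0.2])
    mu = 0.7

    Q = W * pi[:, None]                  # off-diagonal rates Q[i,j] = W[i,j]*pi[i]
    np.fill_diagonal(Q, -Q.sum(axis=0))  # diagonal makes every column sum to zero

    old = -mu * np.einsum('ii,i', Q, pi)          # pre-commit expression
    new =  mu * np.einsum('i,ij,j', pi, W, pi)    # avg_transition with gap_index=None
    assert np.isclose(old, new)

With a gap state present, the new form additionally strips the gap column and renormalises, so average_rate reports the mean rate over non-gap characters only.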
