diff --git a/hmms/cthmm.pyx b/hmms/cthmm.pyx index 06667b9..c2c4559 100755 --- a/hmms/cthmm.pyx +++ b/hmms/cthmm.pyx @@ -739,7 +739,7 @@ cdef class CtHMM(hmm.HMM): for it in range( iterations ): - print("it",it) + # print("it",it) self._prepare_matrices_pt( t_seqs ) @@ -814,23 +814,18 @@ cdef class CtHMM(hmm.HMM): if i == j: - tA /= self._pt[ ix ] + tA = numpy.divide(tA, self._pt[ ix ], out=numpy.zeros_like(tA), where=tA!=0) tau[i] += numpy.exp( self.log_sum( (ksi_sum[ix] + numpy.log( tA ) ).flatten() ) ) #tau is not in log prob anymore. else: tA *= self._q[i,j] - tA /= self._pt[ ix ] + tA = numpy.divide(tA, self._pt[ ix ], out=numpy.zeros_like(tA), where=tA!=0) eta[i,j] += numpy.exp( self.log_sum( (ksi_sum[ix] + numpy.log( tA ) ).flatten() ) ) #eta is not in log prob anymore. #Update parameter Q: - self._q = ( eta.T / tau ).T self._q = numpy.nan_to_num(self._q) # nan can appear, when some of the states is not reachable - - if sum( self._q.flatten() ) == 0: - raise ValueError("Parameter error! Matrix Q can't contain unreachable states.") - for i in range( s_num ): self._q[i,i] = - numpy.sum( self._q[i,:] ) @@ -934,8 +929,7 @@ cdef class CtHMM(hmm.HMM): for it in range( iterations ): - print("iteration ", i+1, "/", iterations ) - + # print("iteration ", it+1, "/", iterations ) self._prepare_matrices_pt( times ) @@ -1033,14 +1027,15 @@ cdef class CtHMM(hmm.HMM): if i == j: - tA /= self._pt[ ix ] + tA = numpy.divide(tA, self._pt[ ix ], out=numpy.zeros_like(tA), where=tA!=0) tau[i] += numpy.exp( self.log_sum( (ksi_sum[ix] + numpy.log( tA ) ).flatten() ) ) #tau is not in log prob anymore. else: tA *= self._q[i,j] - tA /= self._pt[ ix ] + tA = numpy.divide(tA, self._pt[ ix ], out=numpy.zeros_like(tA), where=tA!=0) + eta[i,j] += numpy.exp( self.log_sum( (ksi_sum[ix] + numpy.log( tA ) ).flatten() ) ) #eta is not in log prob anymore. #Update parameters: @@ -1052,13 +1047,8 @@ cdef class CtHMM(hmm.HMM): #jump rates matrice self._q = ( eta.T / tau ).T - - self._q = numpy.nan_to_num(self._q) # nan can appear, when some of the states is not reachable - if sum( self._q.flatten() ) == 0: - raise ValueError("Parameter error! Matrix Q can't contain unreachable states.") - for i in range( s_num ): self._q[i,i] = - numpy.sum( self._q[i,:] )