diff --git a/brahmap/core/noise_ops_toeplitz.py b/brahmap/core/noise_ops_toeplitz.py index 80c90bd..a838abd 100644 --- a/brahmap/core/noise_ops_toeplitz.py +++ b/brahmap/core/noise_ops_toeplitz.py @@ -161,18 +161,25 @@ def __init__( dtype=dtype, ) - self.__precond_atol = precond_atol - self.__precond_maxiter = precond_maxiter - self.__precond_callback = precond_callback + self.precond_atol = precond_atol + self.precond_maxiter = precond_maxiter + self.precond_callback = precond_callback - self.__last_num_iterations = 0 + self.__previous_num_iterations = 0 + + super(InvNoiseCovLO_Toeplitz01, self).__init__( + nargin=size, + matvec=self._mult, + input_type=input_type, + dtype=dtype, + ) if precond_op is None: - self.__precond_op = None + self.precond_op = None elif isinstance(precond_op, LinearOperator) or isinstance( precond_op, np.ndarray ): - self.__precond_op = precond_op + self.precond_op = precond_op elif precond_op in ["Strang", "TChan", "RChan", "KK2"]: if input_type == "power_spectrum": cov = np.fft.ifft(input).real[:size] @@ -204,7 +211,7 @@ def __init__( new_cov[0] = 0 new_cov = cov - new_cov - self.__precond_op = InvNoiseCovLO_Circulant( + self.precond_op = InvNoiseCovLO_Circulant( size=size, input=new_cov, input_type="covariance", @@ -217,29 +224,44 @@ def __init__( message="Invalid preconditioner operator provided!", ) - super(InvNoiseCovLO_Toeplitz01, self).__init__( - nargin=size, - matvec=self._mult, - input_type=input_type, - dtype=dtype, - ) + @property + def precond_op(self): + return self.__precond_op + + @precond_op.setter + def precond_op(self, operator: LinearOperator): + if operator is not None: + MPI_RAISE_EXCEPTION( + condition=(self.shape != operator.shape), + exception=ValueError, + message=f"The shape of the input operator {operator.shape} is not compatible with the shape of inverse Toeplitz operator {self.shape}", + ) + self.__precond_op = operator @property def diag(self) -> np.ndarray: - factor = 1.0 - return factor * np.ones(self.size, dtype=self.dtype) + try: + diag_arr = getattr(self, "__diag") + except AttributeError: + factor = 1.0 + diag_arr = factor * np.ones(self.size, dtype=self.dtype) + return diag_arr + + @diag.setter + def diag(self, diag: np.ndarray): + self.__diag = diag @property - def get_last_num_iterations(self) -> int: - return self.__last_num_iterations + def previous_num_iterations(self) -> int: + return self.__previous_num_iterations def get_inverse(self): return self.__toeplitz_op def __callback(self, x, r, norm_residual): - self.__last_num_iterations += 1 - if self.__precond_callback is not None: - self.__precond_callback(x, r, norm_residual) + self.__previous_num_iterations += 1 + if self.precond_callback is not None: + self.precond_callback(x, r, norm_residual) def _mult(self, vec: np.ndarray): MPI_RAISE_EXCEPTION( @@ -248,7 +270,7 @@ def _mult(self, vec: np.ndarray): message=f"Dimensions of `vec` is not compatible with the dimensions of this `InvNoiseCovLO_Toeplitz` instance.\nShape of `InvNoiseCovLO_Toeplitz` instance: {self.shape}\nShape of `vec`: {vec.shape}", ) - self.__last_num_iterations = 0 + self.__previous_num_iterations = 0 if vec.dtype != self.dtype: if MPI_UTILS.rank == 0: @@ -261,9 +283,9 @@ def _mult(self, vec: np.ndarray): prod, _ = cg( A=self.__toeplitz_op, b=vec, - atol=self.__precond_atol, - maxiter=self.__precond_maxiter, - M=self.__precond_op, + atol=self.precond_atol, + maxiter=self.precond_maxiter, + M=self.precond_op, callback=self.__callback, parallel=False, ) diff --git a/brahmap/core/process_time_samples.py b/brahmap/core/process_time_samples.py index 1ccc285..1f4b8c5 100644 --- a/brahmap/core/process_time_samples.py +++ b/brahmap/core/process_time_samples.py @@ -304,21 +304,14 @@ def bad_pixels(self): def get_hit_counts(self): """Returns hit counts of the pixel indices""" - return self.hit_counts - # hit_counts_newidx = np.zeros(self.new_npix, dtype=int) - # for idx in range(self.nsamples): - # hit_counts_newidx[self.pointings[idx]] += self.pointings_flag[idx] - - # MPI_UTILS.comm.Allreduce(MPI.IN_PLACE, hit_counts_newidx, MPI.SUM) - - # hit_counts = np.ma.masked_array( - # data=np.zeros(self.npix), - # mask=np.logical_not(self.pixel_flag), - # fill_value=-1.6375e30, - # ) + hit_counts = np.ma.masked_array( + data=np.zeros(self.npix), + mask=np.logical_not(self.pixel_flag), + fill_value=-1.6375e30, + ) - # hit_counts[~hit_counts.mask] = hit_counts_newidx - # return hit_counts + hit_counts[~hit_counts.mask] = self.hit_counts + return hit_counts def _compute_weights(self, pol_angles: np.ndarray, noise_weights: np.ndarray): self.hit_counts = np.zeros(self.npix, dtype=self.pointings.dtype)