diff --git a/src/np_utils/cross_corr.py b/src/np_utils/cross_corr.py index cc6274a..607fe12 100644 --- a/src/np_utils/cross_corr.py +++ b/src/np_utils/cross_corr.py @@ -554,18 +554,18 @@ def plot_cc_grid( n_cols (int, optional): Number of columns in the grid. Defaults to 4. avoid_symmetry (bool, optional): If True, only the upper triangle of the grid is plotted. Defaults to True. """ - if channels_i is not None: - channels_i = sorted(channels_i) - if channels_j is not None: - channels_j = sorted(channels_j) if not isinstance(cross_corrs, np.ndarray) or len(cross_corrs.shape) != 3: raise ValueError( "cross_corrs should be a 3D numnpy array of size : n_sources x n_targets x n_bins" ) - sources = range(cross_corrs.shape[0]) if channels_i is None else channels_i - targets = range(cross_corrs.shape[1]) if channels_j is None else channels_j + channels_i = range(cross_corrs.shape[0]) if channels_i is None else sorted(channels_i) + channels_j = range(cross_corrs.shape[1]) if channels_j is None else sorted(channels_j) + + sources = range(cross_corrs.shape[0]) + targets = range(cross_corrs.shape[1]) + for i in tqdm(sources, desc="Plotting cross-correlograms..."): fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(20, 40)) axes = axes.flatten() @@ -578,7 +578,7 @@ def plot_cc_grid( if cc.sum() == 0: continue ax = axes[tot] - self.plot(cc, from_=str(i), to_=str(j), ax=ax) + self.plot(cc, from_=str(channels_i[i]), to_=str(channels_j[j]), ax=ax) except KeyboardInterrupt: return except Exception as e: