Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing repr #267

Merged
merged 2 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 21 additions & 24 deletions docs/examples/tutorial_pynapple_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,44 +36,41 @@
# Discrete correlograms
# ---------------------
#
# The function to compute cross-correlogram is [*cross_correlogram*](https://peyrachelab.github.io/pynapple/process.correlograms/#pynapple.process.correlograms.cross_correlogram).
# First let's generate some data. Here we have two neurons recorded together. We can group them in a `TsGroup`.
#
#
# The function is compiled with [numba](https://numba.pydata.org/) to improve performances. This means it only accepts pure numpy arrays as input arguments.

ts1 = nap.Ts(t=np.sort(np.random.uniform(0, 1000, 2000)), time_units="s")
ts2 = nap.Ts(t=np.sort(np.random.uniform(0, 1000, 1000)), time_units="s")
epoch = nap.IntervalSet(start=0, end=1000, time_units="s")
ts_group = nap.TsGroup({0: ts1, 1: ts2}, time_support=epoch)

ts1 = nap.Ts(t=np.sort(np.random.uniform(0, 1000, 1000)), time_units="s")
ts2 = nap.Ts(t=np.sort(np.random.uniform(0, 1000, 10)), time_units="s")
print(ts_group)

ts1_time_array = ts1.as_units("s").index.values
ts2_time_array = ts2.as_units("s").index.values
# %%
# First we can compute their autocorrelograms meaning the number of spikes of a neuron observed in a time windows centered around its own spikes.
# For this we can use the function `compute_autocorrelogram`.
# We need to specifiy the `binsize` and `windowsize` to bin the spike train.

binsize = 0.1 # second
cc12, xt = nap.process.correlograms.cross_correlogram(
t1=ts1_time_array, t2=ts2_time_array, binsize=binsize, windowsize=1 # second
autocorrs = nap.compute_autocorrelogram(
group=ts_group, binsize=100, windowsize=1000, time_units="ms", ep=epoch # ms
)

plt.figure(figsize=(10, 6))
plt.bar(xt, cc12, binsize)
plt.xlabel("Time t1 (us)")
plt.ylabel("CC")
print(autocorrs, "\n")

# %%
# To simplify converting to a numpy.ndarray, pynapple provides wrappers for computing autocorrelogram and crosscorrelogram for TsGroup. The function is then called for each unit or each pairs of units. It returns directly a pandas.DataFrame holding all the correlograms. In this example, autocorrelograms and cross-correlograms are computed for the same TsGroup.

epoch = nap.IntervalSet(start=0, end=1000, time_units="s")
ts_group = nap.TsGroup({0: ts1, 1: ts2}, time_support=epoch)
# The variable `autocorrs` is a pandas DataFrame with the center of the bins for the index and each columns is a neuron.
#
# Similarly, we can compute crosscorrelograms meaning how many spikes of neuron 1 do I observe whenever neuron 0 fires. Here the function
# is called `compute_crosscorrelogram` and takes a `TsGroup` as well.

autocorrs = nap.compute_autocorrelogram(
group=ts_group, binsize=100, windowsize=1000, time_units="ms", ep=epoch # ms # ms
)
crosscorrs = nap.compute_crosscorrelogram(
group=ts_group, binsize=100, windowsize=1000, time_units="ms" # ms # ms
group=ts_group, binsize=100, windowsize=1000, time_units="ms" # ms
)

print(autocorrs, "\n")
print(crosscorrs, "\n")

# %%
# Column name (0, 1) is read as cross-correlogram of neuron 0 and 1 with neuron 0 being the reference time.

# %%
# ***
# Peri-Event Time Histogram (PETH)
Expand Down
42 changes: 37 additions & 5 deletions pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from .config import nap_config
from .time_index import TsIndex
from .utils import (
_get_terminal_size,
_IntervalSetSliceHelper,
_jitfix_iset,
convert_to_numpy,
Expand Down Expand Up @@ -172,11 +173,39 @@ def __repr__(self):
headers = ["start", "end"]
bottom = "shape: {}, time unit: sec.".format(self.shape)

return (
tabulate(self.values, headers=headers, showindex="always", tablefmt="plain")
+ "\n"
+ bottom
)
rows = _get_terminal_size()[1]
max_rows = np.maximum(rows - 10, 6)

if len(self) > max_rows:
n_rows = max_rows // 2
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return (
tabulate(
self.values[0:n_rows],
headers=headers,
showindex=self.index[0:n_rows],
tablefmt="plain",
)
+ "\n\n...\n"
+ tabulate(
self.values[-n_rows:],
headers=[" " * 5, " " * 3], # To align properly the columns
showindex=self.index[-n_rows:],
tablefmt="plain",
)
+ "\n"
+ bottom
)

else:
return (
tabulate(
self.values, headers=headers, showindex="always", tablefmt="plain"
)
+ "\n"
+ bottom
)

def __str__(self):
return self.__repr__()
Expand All @@ -203,6 +232,9 @@ def __getitem__(self, key, *args, **kwargs):
elif isinstance(key, (list, slice, np.ndarray)):
output = self.values.__getitem__(key)
return IntervalSet(start=output[:, 0], end=output[:, 1])
elif isinstance(key, pd.Series):
output = self.values.__getitem__(key)
return IntervalSet(start=output[:, 0], end=output[:, 1])
elif isinstance(key, tuple):
if len(key) == 2:
if isinstance(key[1], Number):
Expand Down
137 changes: 80 additions & 57 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .time_index import TsIndex
from .utils import (
_concatenate_tsd,
_get_terminal_size,
_split_tsd,
_TsdFrameSliceHelper,
convert_to_numpy,
Expand Down Expand Up @@ -620,42 +621,39 @@ def __repr__(self):
headers = ["Time (s)", ""]
bottom = "dtype: {}".format(self.dtype) + ", shape: {}".format(self.shape)

max_rows = 2
rows = _get_terminal_size()[1]
max_rows = np.maximum(rows - 10, 2)

if len(self):

def create_str(array):
if array.ndim == 1:
if len(array) > 2:
return (
"["
+ array[0].__repr__()
+ " ... "
+ array[-1].__repr__()
+ "]"
)
elif len(array) == 2:
return (
"[" + array[0].__repr__() + "," + array[1].__repr__() + "]"
return np.array2string(
np.array([array[0], array[-1]]),
precision=6,
separator=" ... ",
)
elif len(array) == 1:
return "[" + array[0].__repr__() + "]"
else:
return "[]"
return np.array2string(array, precision=6, separator=", ")
else:
return "[" + create_str(array[0]) + " ...]"

_str_ = []
if self.shape[0] < 100:
for i, array in zip(self.index, self.values):
_str_.append([i.__repr__(), create_str(array)])
else:
for i, array in zip(self.index[0:5], self.values[0:5]):
if self.shape[0] > max_rows:
n_rows = max_rows // 2
for i, array in zip(self.index[0:n_rows], self.values[0:n_rows]):
_str_.append([i.__repr__(), create_str(array)])
_str_.append(["...", ""])
for i, array in zip(
self.index[-5:],
self.values[self.values.shape[0] - 5 : self.values.shape[0]],
self.index[-n_rows:],
self.values[self.values.shape[0] - n_rows : self.values.shape[0]],
):
_str_.append([i.__repr__(), create_str(array)])
else:
for i, array in zip(self.index, self.values):
_str_.append([i.__repr__(), create_str(array)])

return tabulate(_str_, headers=headers, colalign=("left",)) + "\n" + bottom

Expand Down Expand Up @@ -818,40 +816,52 @@ def __repr__(self):
headers = ["Time (s)"] + [str(k) for k in self.columns]
bottom = "dtype: {}".format(self.dtype) + ", shape: {}".format(self.shape)

max_cols = 5
try:
max_cols = os.get_terminal_size()[0] // 16
except Exception:
import shutil

max_cols = shutil.get_terminal_size().columns // 16
else:
pass
cols, rows = _get_terminal_size()
max_cols = np.maximum(cols // 100, 5)
max_rows = np.maximum(rows - 10, 2)

if self.shape[1] > max_cols:
headers = headers[0 : max_cols + 1] + ["..."]

def round_if_float(x):
if isinstance(x, float):
return np.round(x, 5)
else:
return x

with warnings.catch_warnings():
warnings.simplefilter("ignore")
if len(self):
table = []
end = ["..."] if self.shape[1] > max_cols else []
if len(self) > 51:
for i, array in zip(self.index[0:5], self.values[0:5, 0:max_cols]):
table.append([i] + [k for k in array] + end)
if len(self) > max_rows:
n_rows = max_rows // 2
for i, array in zip(
self.index[0:n_rows], self.values[0:n_rows, 0:max_cols]
):
table.append([i] + [round_if_float(k) for k in array] + end)
table.append(["..."])
for i, array in zip(
self.index[-5:],
self.index[-n_rows:],
self.values[
self.values.shape[0] - 5 : self.values.shape[0], 0:max_cols
self.values.shape[0] - n_rows : self.values.shape[0],
0:max_cols,
],
):
table.append([i] + [k for k in array] + end)
return tabulate(table, headers=headers) + "\n" + bottom
table.append([i] + [round_if_float(k) for k in array] + end)
return (
tabulate(table, headers=headers, colalign=("left",))
+ "\n"
+ bottom
)
else:
for i, array in zip(self.index, self.values[:, 0:max_cols]):
table.append([i] + [k for k in array] + end)
return tabulate(table, headers=headers) + "\n" + bottom
table.append([i] + [round_if_float(k) for k in array] + end)
return (
tabulate(table, headers=headers, colalign=("left",))
+ "\n"
+ bottom
)
else:
return tabulate([], headers=headers) + "\n" + bottom

Expand Down Expand Up @@ -1053,27 +1063,24 @@ def __repr__(self):
headers = ["Time (s)", ""]
bottom = "dtype: {}".format(self.dtype) + ", shape: {}".format(self.shape)

max_rows = 2
rows = _get_terminal_size()[1]
max_rows = np.maximum(rows - 10, 2)

with warnings.catch_warnings():
warnings.simplefilter("ignore")
if len(self):
if len(self) < 51:
return (
tabulate(
np.vstack((self.index, self.values)).T,
headers=headers,
colalign=("left",),
)
+ "\n"
+ bottom
)
else:
if len(self) > max_rows:
n_rows = max_rows // 2
table = []
for i, v in zip(self.index[0:5], self.values[0:5]):
for i, v in zip(self.index[0:n_rows], self.values[0:n_rows]):
table.append([i, v])
table.append(["..."])
for i, v in zip(
self.index[-5:],
self.values[self.values.shape[0] - 5 : self.values.shape[0]],
self.index[-n_rows:],
self.values[
self.values.shape[0] - n_rows : self.values.shape[0]
],
):
table.append([i, v])

Expand All @@ -1082,6 +1089,16 @@ def __repr__(self):
+ "\n"
+ bottom
)
else:
return (
tabulate(
np.vstack((self.index, self.values)).T,
headers=headers,
colalign=("left",),
)
+ "\n"
+ bottom
)
else:
return tabulate([], headers=headers) + "\n" + bottom

Expand Down Expand Up @@ -1357,14 +1374,20 @@ def __init__(self, t, time_units="s", time_support=None):

def __repr__(self):
upper = "Time (s)"
if len(self) < 50:
_str_ = "\n".join([i.__repr__() for i in self.index])
else:

max_rows = 2
rows = _get_terminal_size()[1]
max_rows = np.maximum(rows - 10, 2)

if len(self) > max_rows:
n_rows = max_rows // 2
_str_ = "\n".join(
[i.__repr__() for i in self.index[0:5]]
[i.__repr__() for i in self.index[0:n_rows]]
+ ["..."]
+ [i.__repr__() for i in self.index[-5:]]
+ [i.__repr__() for i in self.index[-n_rows:]]
)
else:
_str_ = "\n".join([i.__repr__() for i in self.index])

bottom = "shape: {}".format(len(self.index))
return "\n".join((upper, _str_, bottom))
Expand Down
Loading
Loading