Skip to content

Commit

Permalink
pyright fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
magland committed Nov 8, 2023
1 parent 13975a2 commit cc1ae50
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pyright_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
- name: Install
run: pip install .
- name: Install pyright
run: pip install pyright
run: pip install --upgrade pyright
- name: Install other
run: pip install pyvips
- name: Run pyright
Expand Down
2 changes: 1 addition & 1 deletion examples/example_average_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def compute_average_waveform(*, recording: si.BaseRecording, sorting: si.BaseSor
snippets = extract_snippets(traces=traces, times=times, snippet_len=(20, 20))
waveform = np.mean(snippets, axis=0).T.astype(np.float32)
stdev = np.std(snippets, axis=0).T.astype(np.float32)
waveform_percentiles = np.percentile(snippets, [5, 25, 75, 95], axis=0)
waveform_percentiles = np.percentile(snippets, [5, 25, 75, 95], axis=0) # type: ignore
waveform_percentiles = [waveform_percentiles[i].T.astype(np.float32) for i in range(4)]
return {"channel_ids": recording.get_channel_ids().astype(np.int32), "waveform": waveform, "waveform_std_dev": stdev, "waveform_percentiles": waveform_percentiles}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@ def prepare_spikesortingview_data(
start_frame_with_padding = max(start_frame - snippet_len[0], 0)
end_frame_with_padding = min(end_frame + snippet_len[1], num_frames)
traces_with_padding = recording.get_traces(start_frame=start_frame_with_padding, end_frame=end_frame_with_padding)
assert isinstance(traces_with_padding, np.ndarray)
for unit_id in unit_ids:
if str(unit_id) not in unit_peak_channel_ids:
spike_train = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=start_frame, end_frame=end_frame)
assert isinstance(spike_train, np.ndarray)
if len(spike_train) > 0:
values = traces_with_padding[spike_train - start_frame_with_padding, :]
values = traces_with_padding[spike_train.astype(np.int32) - start_frame_with_padding, :]
avg_value = np.mean(values, axis=0)
peak_channel_ind = np.argmax(np.abs(avg_value))
peak_channel_id = channel_ids[peak_channel_ind]
Expand Down

0 comments on commit cc1ae50

Please sign in to comment.