Skip to content

Commit

Permalink
Merge pull request #250 from AxFoundation/protect_wraparound
Browse files Browse the repository at this point in the history
Simplify length computations / protect from wraparounds
  • Loading branch information
JelleAalbers authored Apr 6, 2020
2 parents 8aa6644 + 72cc731 commit b58d744
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 37 deletions.
27 changes: 1 addition & 26 deletions strax/processing/data_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def cut_baseline(records, n_before=48, n_after=30):
d.data[:n_before] = 0

clear_from = d.pulse_length - n_after
clear_from -= d.record_i * samples_per_record
clear_from -= d.record_i.astype(np.int32) * samples_per_record
clear_from = max(0, clear_from)
if clear_from < samples_per_record:
d.data[clear_from:] = 0
Expand Down Expand Up @@ -120,28 +120,3 @@ def _cut_outside_hits(records, hits, new_recs,
b_next = end_keep - samples_per_record
new_recs[next_ri]['data'][:b_next] = \
records[next_ri]['data'][:b_next]


@export
@numba.jit(nopython=True, nogil=True, cache=True)
def replace_with_spike(records, also_for_multirecord_pulses=False):
"""Replaces the waveform in each record with a spike of the same integral
:param also_for_multirecord_pulses: if True, does this even if the pulse
spans multiple records (so you'll get more than one spike...)
"""
if not len(records):
return
samples_per_record = len(records[0]['data'])

for i, d in enumerate(records):
if not (d.record_i == 0 or also_for_multirecord_pulses):
continue
# What is the center of this record? It's nontrivial since
# some records have parts that do not represent data at the end
center = int(min(d.total_length - samples_per_record * d.record_i,
samples_per_record) // 2)
integral = d.data.sum()
d.data[:] = 0
d.data[center] = integral

records.reduction_level[:] = ReductionLevel.WAVEFORM_REPLACED
16 changes: 5 additions & 11 deletions strax/processing/pulse_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ def baseline(records, baseline_samples=40, flip=True,

# Subtract baseline from all data samples in the record
# (any additional zeros should be kept at zero)
last = min(samples_per_record,
d['pulse_length'] - d['record_i'] * samples_per_record)
d['data'][:last] = (-1 * flip) * (d['data'][:last] - int(bl))
d['data'][:d['length']] = (
(-1 * flip) * (d['data'][:d['length']] - int(bl)))
d['baseline'] = bl
d['baseline_rms'] = rms

Expand Down Expand Up @@ -114,9 +113,8 @@ def zero_out_of_bounds(records):
samples_per_record = len(records[0]['data'])

for r in records:
end = r['pulse_length'] - r['record_i'] * samples_per_record
if end < samples_per_record:
r['data'][end:] = 0
if r['length'] < samples_per_record:
r['data'][r['length']:] = 0


@export
Expand All @@ -125,16 +123,12 @@ def integrate(records):
"""Integrate records in-place"""
if not len(records):
return
samples_per_record = len(records[0]['data'])
for i, r in enumerate(records):
n_real_samples = min(
samples_per_record,
r['pulse_length'] - r['record_i'] * samples_per_record)
records[i]['area'] = (
r['data'].sum()
# Add floating part of baseline * number of samples
# int(round()) the result since the area field is an int
+ int(round((r['baseline'] % 1) * n_real_samples)))
+ int(round((r['baseline'] % 1) * r['length'])))


@export
Expand Down

0 comments on commit b58d744

Please sign in to comment.