Skip to content

Commit

Permalink
new generic interface for spike_sync functions
Browse files Browse the repository at this point in the history
Similar to the isi and spike distance functions, also the spike sync functions
now support the new generic interface.
  • Loading branch information
mariomulansky committed Feb 2, 2016
1 parent ea3709e commit a57f3d5
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 36 deletions.
136 changes: 101 additions & 35 deletions pyspike/spike_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,40 @@
############################################################
# spike_sync_profile
############################################################
def spike_sync_profile(spike_train1, spike_train2, max_tau=None):
def spike_sync_profile(*args, **kwargs):
""" Computes the spike-synchronization profile S_sync(t) of the given
spike trains. Returns the profile as a DiscreteFunction object. In the
bivariate case, he S_sync values are either 1 or 0, indicating the presence
or absence of a coincidence. For multi-variate cases, each spike in the set
of spike trains, the profile is defined as the number of coincidences
divided by the number of spike trains pairs involving the spike train of
containing this spike, which is the number of spike trains minus one (N-1).
Valid call structures::
spike_sync_profile(st1, st2) # returns the bi-variate profile
spike_sync_profile(st1, st2, st3) # multi-variate profile of 3 sts
sts = [st1, st2, st3, st4] # list of spike trains
spike_sync_profile(sts) # profile of the list of spike trains
spike_sync_profile(sts, indices=[0, 1]) # use only the spike trains
# given by the indices
:returns: The spike-sync profile :math:`S_{sync}(t)`.
:rtype: :class:`pyspike.function.DiscreteFunction`
"""
if len(args) == 1:
return spike_sync_profile_multi(args[0], **kwargs)
elif len(args) == 2:
return spike_sync_profile_bi(args[0], args[1])
else:
return spike_sync_profile_multi(args)


############################################################
# spike_sync_profile_bi
############################################################
def spike_sync_profile_bi(spike_train1, spike_train2, max_tau=None):
""" Computes the spike-synchronization profile S_sync(t) of the two given
spike trains. Returns the profile as a DiscreteFunction object. The S_sync
values are either 1 or 0, indicating the presence or absence of a
Expand All @@ -27,7 +60,7 @@ def spike_sync_profile(spike_train1, spike_train2, max_tau=None):
:type spike_train2: :class:`pyspike.SpikeTrain`
:param max_tau: Maximum coincidence window size. If 0 or `None`, the
coincidence window has no upper bound.
:returns: The spike-distance profile :math:`S_{sync}(t)`.
:returns: The spike-sync profile :math:`S_{sync}(t)`.
:rtype: :class:`pyspike.function.DiscreteFunction`
"""
Expand Down Expand Up @@ -61,6 +94,33 @@ def spike_sync_profile(spike_train1, spike_train2, max_tau=None):
return DiscreteFunc(times, coincidences, multiplicity)


############################################################
# spike_sync_profile_multi
############################################################
def spike_sync_profile_multi(spike_trains, indices=None, max_tau=None):
""" Computes the multi-variate spike synchronization profile for a set of
spike trains. For each spike in the set of spike trains, the multi-variate
profile is defined as the number of coincidences divided by the number of
spike trains pairs involving the spike train of containing this spike,
which is the number of spike trains minus one (N-1).
:param spike_trains: list of :class:`pyspike.SpikeTrain`
:param indices: list of indices defining which spike trains to use,
if None all given spike trains are used (default=None)
:type indices: list or None
:param max_tau: Maximum coincidence window size. If 0 or `None`, the
coincidence window has no upper bound.
:returns: The multi-variate spike sync profile :math:`<S_{sync}>(t)`
:rtype: :class:`pyspike.function.DiscreteFunction`
"""
prof_func = partial(spike_sync_profile_bi, max_tau=max_tau)
average_prof, M = _generic_profile_multi(spike_trains, prof_func,
indices)
# average_dist.mul_scalar(1.0/M) # no normalization here!
return average_prof


############################################################
# _spike_sync_values
############################################################
Expand All @@ -87,18 +147,51 @@ def _spike_sync_values(spike_train1, spike_train2, interval, max_tau):
return c, mp
except ImportError:
# Cython backend not available: fall back to profile averaging
return spike_sync_profile(spike_train1, spike_train2,
max_tau).integral(interval)
return spike_sync_profile_bi(spike_train1, spike_train2,
max_tau).integral(interval)
else:
# some specific interval is provided: use profile
return spike_sync_profile(spike_train1, spike_train2,
max_tau).integral(interval)
return spike_sync_profile_bi(spike_train1, spike_train2,
max_tau).integral(interval)


############################################################
# spike_sync
############################################################
def spike_sync(spike_train1, spike_train2, interval=None, max_tau=None):
def spike_sync(*args, **kwargs):
""" Computes the spike synchronization value SYNC of the given spike
trains. The spike synchronization value is the computed as the total number
of coincidences divided by the total number of spikes:
.. math:: SYNC = \sum_n C_n / N.
Valid call structures::
spike_sync(st1, st2) # returns the bi-variate spike synchronization
spike_sync(st1, st2, st3) # multi-variate result for 3 spike trains
spike_trains = [st1, st2, st3, st4] # list of spike trains
spike_sync(spike_trains) # spike-sync of the list of spike trains
spike_sync(spike_trains, indices=[0, 1]) # use only the spike trains
# given by the indices
:returns: The spike synchronization value.
:rtype: `double`
"""

if len(args) == 1:
return spike_sync_multi(args[0], **kwargs)
elif len(args) == 2:
return spike_sync_bi(args[0], args[1], **kwargs)
else:
return spike_sync_multi(args, **kwargs)


############################################################
# spike_sync_bi
############################################################
def spike_sync_bi(spike_train1, spike_train2, interval=None, max_tau=None):
""" Computes the spike synchronization value SYNC of the given spike
trains. The spike synchronization value is the computed as the total number
of coincidences divided by the total number of spikes:
Expand All @@ -122,33 +215,6 @@ def spike_sync(spike_train1, spike_train2, interval=None, max_tau=None):
return 1.0*c/mp


############################################################
# spike_sync_profile_multi
############################################################
def spike_sync_profile_multi(spike_trains, indices=None, max_tau=None):
""" Computes the multi-variate spike synchronization profile for a set of
spike trains. For each spike in the set of spike trains, the multi-variate
profile is defined as the number of coincidences divided by the number of
spike trains pairs involving the spike train of containing this spike,
which is the number of spike trains minus one (N-1).
:param spike_trains: list of :class:`pyspike.SpikeTrain`
:param indices: list of indices defining which spike trains to use,
if None all given spike trains are used (default=None)
:type indices: list or None
:param max_tau: Maximum coincidence window size. If 0 or `None`, the
coincidence window has no upper bound.
:returns: The multi-variate spike sync profile :math:`<S_{sync}>(t)`
:rtype: :class:`pyspike.function.DiscreteFunction`
"""
prof_func = partial(spike_sync_profile, max_tau=max_tau)
average_prof, M = _generic_profile_multi(spike_trains, prof_func,
indices)
# average_dist.mul_scalar(1.0/M) # no normalization here!
return average_prof


############################################################
# spike_sync_multi
############################################################
Expand Down Expand Up @@ -211,6 +277,6 @@ def spike_sync_matrix(spike_trains, indices=None, interval=None, max_tau=None):
:rtype: np.array
"""
dist_func = partial(spike_sync, max_tau=max_tau)
dist_func = partial(spike_sync_bi, max_tau=max_tau)
return _generic_distance_matrix(spike_trains, dist_func,
indices, interval)
14 changes: 13 additions & 1 deletion test/test_generic_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" test_isi_interface.py
""" test_generic_interface.py
Tests the generic interfaces of the profile and distance functions
Expand Down Expand Up @@ -88,6 +88,18 @@ def test_spike_distance():
check_func(spk.spike_distance)


def test_spike_sync_profile():
check_func(dist_from_prof(spk.spike_sync_profile))


def test_spike_sync():
check_func(spk.spike_sync)


if __name__ == "__main__":
test_isi_profile()
test_isi_distance()
test_spike_profile()
test_spike_distance()
test_spike_sync_profile()
test_spike_sync()

0 comments on commit a57f3d5

Please sign in to comment.