Skip to content

Commit

Permalink
generic interface for spike distance/profile
Browse files Browse the repository at this point in the history
spike_profile and spike_distance now have a generic interface that allows
to compute bi-variate and multi-variate results with the same function.
  • Loading branch information
mariomulansky committed Jan 31, 2016
1 parent 5a556a1 commit ea3709e
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 35 deletions.
17 changes: 14 additions & 3 deletions pyspike/isi_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
# isi_profile
############################################################
def isi_profile(*args, **kwargs):
""" Computes the isi-distance profile :math:`I(t)` of the two given
""" Computes the isi-distance profile :math:`I(t)` of the given
spike trains. Returns the profile as a PieceWiseConstFunc object. The
ISI-values are defined positive :math:`I(t)>=0`.
Valid call structures::
isi_profile(st1, st2) # returns the bi-variate profile
isi_profile(st1, st2, st3) # multi-variate profile of 3 spike trains
spike_trains = [st1, st2, st3, st4] # list of spike trains
isi_profile(spike_trains) # return the profile the list of spike trains
isi_profile(spike_trains) # profile of the list of spike trains
isi_profile(spike_trains, indices=[0, 1]) # use only the spike trains
# given by the indices
Expand Down Expand Up @@ -108,13 +109,23 @@ def isi_profile_multi(spike_trains, indices=None):
# isi_distance
############################################################
def isi_distance(*args, **kwargs):
# spike_trains, spike_train2, interval=None):
""" Computes the ISI-distance :math:`D_I` of the given spike trains. The
isi-distance is the integral over the isi distance profile
:math:`I(t)`:
.. math:: D_I = \\int_{T_0}^{T_1} I(t) dt.
Valid call structures::
isi_distance(st1, st2) # returns the bi-variate distance
isi_distance(st1, st2, st3) # multi-variate distance of 3 spike trains
spike_trains = [st1, st2, st3, st4] # list of spike trains
isi_distance(spike_trains) # distance of the list of spike trains
isi_distance(spike_trains, indices=[0, 1]) # use only the spike trains
# given by the indices
:returns: The isi-distance :math:`D_I`.
:rtype: double
"""
Expand Down
124 changes: 93 additions & 31 deletions pyspike/spike_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,36 @@
############################################################
# spike_profile
############################################################
def spike_profile(spike_train1, spike_train2):
def spike_profile(*args, **kwargs):
""" Computes the spike-distance profile :math:`S(t)` of the given
spike trains. Returns the profile as a PieceWiseConstLin object. The
SPIKE-values are defined positive :math:`S(t)>=0`.
Valid call structures::
spike_profile(st1, st2) # returns the bi-variate profile
spike_profile(st1, st2, st3) # multi-variate profile of 3 spike trains
spike_trains = [st1, st2, st3, st4] # list of spike trains
spike_profile(spike_trains) # profile of the list of spike trains
spike_profile(spike_trains, indices=[0, 1]) # use only the spike trains
# given by the indices
:returns: The spike-distance profile :math:`S(t)`
:rtype: :class:`.PieceWiseConstLin`
"""
if len(args) == 1:
return spike_profile_multi(args[0], **kwargs)
elif len(args) == 2:
return spike_profile_bi(args[0], args[1])
else:
return spike_profile_multi(args)


############################################################
# spike_profile_bi
############################################################
def spike_profile_bi(spike_train1, spike_train2):
""" Computes the spike-distance profile :math:`S(t)` of the two given spike
trains. Returns the profile as a PieceWiseLinFunc object. The SPIKE-values
are defined positive :math:`S(t)>=0`.
Expand Down Expand Up @@ -53,10 +82,68 @@ def spike_profile(spike_train1, spike_train2):
return PieceWiseLinFunc(times, y_starts, y_ends)


############################################################
# spike_profile_multi
############################################################
def spike_profile_multi(spike_trains, indices=None):
""" Computes the multi-variate spike distance profile for a set of spike
trains. That is the average spike-distance of all pairs of spike-trains:
.. math:: <S(t)> = \\frac{2}{N(N-1)} \\sum_{<i,j>} S^{i, j}`,
where the sum goes over all pairs <i,j>
:param spike_trains: list of :class:`.SpikeTrain`
:param indices: list of indices defining which spike trains to use,
if None all given spike trains are used (default=None)
:type indices: list or None
:returns: The averaged spike profile :math:`<S>(t)`
:rtype: :class:`.PieceWiseLinFunc`
"""
average_dist, M = _generic_profile_multi(spike_trains, spike_profile_bi,
indices)
average_dist.mul_scalar(1.0/M) # normalize
return average_dist


############################################################
# spike_distance
############################################################
def spike_distance(spike_train1, spike_train2, interval=None):
def spike_distance(*args, **kwargs):
""" Computes the SPIKE-distance :math:`D_S` of the given spike trains. The
spike-distance is the integral over the spike distance profile
:math:`D(t)`:
.. math:: D_S = \\int_{T_0}^{T_1} S(t) dt.
Valid call structures::
spike_distance(st1, st2) # returns the bi-variate distance
spike_distance(st1, st2, st3) # multi-variate distance of 3 spike trains
spike_trains = [st1, st2, st3, st4] # list of spike trains
spike_distance(spike_trains) # distance of the list of spike trains
spike_distance(spike_trains, indices=[0, 1]) # use only the spike trains
# given by the indices
:returns: The spike-distance :math:`D_S`.
:rtype: double
"""

if len(args) == 1:
return spike_distance_multi(args[0], **kwargs)
elif len(args) == 2:
return spike_distance_bi(args[0], args[1], **kwargs)
else:
return spike_distance_multi(args, **kwargs)


############################################################
# spike_distance_bi
############################################################
def spike_distance_bi(spike_train1, spike_train2, interval=None):
""" Computes the spike-distance :math:`D_S` of the given spike trains. The
spike-distance is the integral over the spike distance profile
:math:`S(t)`:
Expand Down Expand Up @@ -86,35 +173,10 @@ def spike_distance(spike_train1, spike_train2, interval=None):
spike_train1.t_end)
except ImportError:
# Cython backend not available: fall back to average profile
return spike_profile(spike_train1, spike_train2).avrg(interval)
return spike_profile_bi(spike_train1, spike_train2).avrg(interval)
else:
# some specific interval is provided: compute the whole profile
return spike_profile(spike_train1, spike_train2).avrg(interval)


############################################################
# spike_profile_multi
############################################################
def spike_profile_multi(spike_trains, indices=None):
""" Computes the multi-variate spike distance profile for a set of spike
trains. That is the average spike-distance of all pairs of spike-trains:
.. math:: <S(t)> = \\frac{2}{N(N-1)} \\sum_{<i,j>} S^{i, j}`,
where the sum goes over all pairs <i,j>
:param spike_trains: list of :class:`.SpikeTrain`
:param indices: list of indices defining which spike trains to use,
if None all given spike trains are used (default=None)
:type indices: list or None
:returns: The averaged spike profile :math:`<S>(t)`
:rtype: :class:`.PieceWiseLinFunc`
"""
average_dist, M = _generic_profile_multi(spike_trains, spike_profile,
indices)
average_dist.mul_scalar(1.0/M) # normalize
return average_dist
return spike_profile_bi(spike_train1, spike_train2).avrg(interval)


############################################################
Expand All @@ -139,7 +201,7 @@ def spike_distance_multi(spike_trains, indices=None, interval=None):
:returns: The averaged multi-variate spike distance :math:`D_S`.
:rtype: double
"""
return _generic_distance_multi(spike_trains, spike_distance, indices,
return _generic_distance_multi(spike_trains, spike_distance_bi, indices,
interval)


Expand All @@ -160,5 +222,5 @@ def spike_distance_matrix(spike_trains, indices=None, interval=None):
:math:`D_S^{ij}`
:rtype: np.array
"""
return _generic_distance_matrix(spike_trains, spike_distance,
return _generic_distance_matrix(spike_trains, spike_distance_bi,
indices, interval)
31 changes: 30 additions & 1 deletion test/test_generic_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ def __init__(self, prof_func):
self.prof_func = prof_func

def __call__(self, *args, **kwargs):
return self.prof_func(*args, **kwargs).avrg()
if "interval" in kwargs:
# forward interval arg into avrg function
interval = kwargs.pop("interval")
return self.prof_func(*args, **kwargs).avrg(interval=interval)
else:
return self.prof_func(*args, **kwargs).avrg()


def check_func(dist_func):
Expand All @@ -50,6 +55,22 @@ def check_func(dist_func):
isi123_ = dist_func(spike_trains, indices=[0, 1, 2])
assert_equal(isi123, isi123_)

# run the same test with an additional interval parameter

isi12 = dist_func(t1, t2, interval=[0.0, 0.5])
isi12_ = dist_func([t1, t2], interval=[0.0, 0.5])
assert_equal(isi12, isi12_)

isi12_ = dist_func(spike_trains, indices=[0, 1], interval=[0.0, 0.5])
assert_equal(isi12, isi12_)

isi123 = dist_func(t1, t2, t3, interval=[0.0, 0.5])
isi123_ = dist_func([t1, t2, t3], interval=[0.0, 0.5])
assert_equal(isi123, isi123_)

isi123_ = dist_func(spike_trains, indices=[0, 1, 2], interval=[0.0, 0.5])
assert_equal(isi123, isi123_)


def test_isi_profile():
check_func(dist_from_prof(spk.isi_profile))
Expand All @@ -59,6 +80,14 @@ def test_isi_distance():
check_func(spk.isi_distance)


def test_spike_profile():
check_func(dist_from_prof(spk.spike_profile))


def test_spike_distance():
check_func(spk.spike_distance)


if __name__ == "__main__":
test_isi_profile()
test_isi_distance()
1 change: 1 addition & 0 deletions test/test_regression/test_regression_15.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TEST_PATH = os.path.dirname(os.path.realpath(__file__))
TEST_DATA = os.path.join(TEST_PATH, "..", "SPIKE_Sync_Test.txt")


def test_regression_15_isi():
# load spike trains
spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=[0, 4000])
Expand Down

0 comments on commit ea3709e

Please sign in to comment.