diff --git a/btmorph2/auxFuncs.py b/btmorph2/auxFuncs.py index 92833fe..d56b7a1 100644 --- a/btmorph2/auxFuncs.py +++ b/btmorph2/auxFuncs.py @@ -1,6 +1,7 @@ import numpy as np from collections import Counter + # ********************************************************************************************************************** def readSWC_numpy(swcFile): @@ -140,4 +141,55 @@ def transSWC_rotAboutPoint(fName, A, b, destFle, point): raise TypeError('Data in the input file is of unknown format.') np.savetxt(destFle, data, header=headr, fmt=formatStr) -#*********************************************************************************************************************** \ No newline at end of file +#*********************************************************************************************************************** + + +def getIntersectionXYZs(p1, p2, centeredAt, radius): + """ + Calculates and returns the points of intersection between the line joining the points p1 and p2 + and the circle centered at centeredAt with radius radius. The points are ordered as they would be encountered when + moving from p1 to c2 + :param p1: 3 member float iterable + :param p2: 3 member float iterable + :param centeredAt: 3 member float iterable + :param radius: float + :return: iterable of intersections, each interesection being a 3 member float iterable + """ + + # Solving |x + alpha *y| = radius where + # x is the vector p1 - centeredAt, + # y is the vector c2 - centeredAt, + # alpha is a float in [0, 1] + + p1 = np.array(p1) + p2 = np.array(p2) + centeredAt = np.array(centeredAt) + + x = p1 - centeredAt + y = p2 - centeredAt + + modx = np.linalg.norm(x) + + yMx = y - x + mod_yMx = np.linalg.norm(yMx) + + xDotyMx = np.dot(x, yMx) + + # the above problem reduces to solving A * alpha^2 + B * alpha + C = 0 where + + A = mod_yMx ** 2 + B = 2 * xDotyMx + C = modx ** 2 - radius ** 2 + + roots = np.roots([A, B, C]) + roots = np.round(roots, 3) + if all(np.isreal(roots)): + roots = np.sort(roots) + + intersections = [(x + alpha * (y - x)).tolist() for alpha in roots if np.isreal(alpha) and 1 >= alpha >= 0] + + return intersections + +#*********************************************************************************************************************** + + diff --git a/btmorph2/structs/neuronMorph.py b/btmorph2/structs/neuronMorph.py index 768c96b..20626ff 100644 --- a/btmorph2/structs/neuronMorph.py +++ b/btmorph2/structs/neuronMorph.py @@ -10,6 +10,7 @@ import numpy as np import sys from .defaults import defaultGlobalScalarFuncs +from ..auxFuncs import getIntersectionXYZs class NeuronMorphology(object): ''' @@ -616,7 +617,7 @@ def avg_bif_angle_local(self): bifAngles = [self.bifurcation_angle_vec(n, where="local") for n in self._bif_points] return float(np.mean(bifAngles)), bifAngles else: - return float('nan') + return float('nan'), [float("nan")] def avg_bif_angle_remote(self): """ @@ -631,7 +632,7 @@ def avg_bif_angle_remote(self): bifAngles = [self.bifurcation_angle_vec(n, where="remote") for n in self._bif_points] return float(np.mean(bifAngles)), bifAngles else: - return float('nan') + return float('nan'), [float("nan")] def avg_partition_asymmetry(self): """ @@ -699,7 +700,7 @@ def avg_sibling_ratio_local(self): return float(np.mean(siblingRatios_local)), siblingRatios_local else: - return float("nan") + return float("nan"), [float("nan")] """ Local measures @@ -1297,3 +1298,174 @@ def getGlobalScalarMeasures(self, funcs=None): return swcDataS + def getIntersectionsVsDistance(self, radii, centeredAt=None): + """ + Calculates and returns the number of intersections of the morphology with concentric spheres of radii in + input argument radii + :param radii: iterable of non-negative floats of size at least 2, radii of spheres concerned + :param centeredAt: iterable of floats of size 3, containing the X, Y and Z coordinates of the center of the + spheres + :return: list of same size as radii, of intersections, corresponding to radii in input argument radii + """ + + assert len(radii) >= 2, "Input argument radii must have at least 2 numbers, got {}".format(radii) + assert all([x >= 0 for x in radii]), "Input argument radii can only consist of non-negative numbers, " \ + "got {}".format(radii) + intersects = [0 for x in radii] + radii = list(radii) + radiiSorted = np.sort(radii) + + if centeredAt is None: + centeredAt = self.get_tree().root.content["p3d"].xyz + + assert len(centeredAt) == 3, "Input argument centeredAt must be a 3 member iterable of numbers, " \ + "got {}".format(centeredAt) + centeredAt = np.asarray(centeredAt) + + def nodeDistance(n): + nXYZ = np.asarray(n.content["p3d"].xyz) + return np.linalg.norm(nXYZ - centeredAt) + + def nodeXYZ(n): + return np.asarray(n.content["p3d"].xyz) + + allNodesExceptRoot = [x for x in self.get_tree().breadth_first_iterator_generator() if x.parent] + + for node in allNodesExceptRoot: + + nodeDist = nodeDistance(node) + parentDist = nodeDistance(node.parent) + + # in case node is farther than the parent + if nodeDist > parentDist: + fartherDist = nodeDist + fartherXYZ = nodeXYZ(node) + nearerDist = parentDist + nearerXYZ = nodeXYZ(node.parent) + else: + fartherDist = parentDist + fartherXYZ = nodeXYZ(node.parent) + nearerDist = nodeDist + nearerXYZ = nodeXYZ(node) + + if fartherDist > radiiSorted[0]: + radiiCrossedMask = np.logical_and(nearerDist < radiiSorted, radiiSorted <= fartherDist) + radiiCrossed = radiiSorted[radiiCrossedMask] + + if len(radiiCrossed) == 0: + radiiCrossed = [] + largestRLessNearerPoint = radiiSorted[radiiSorted <= nearerDist].max() + currentIntersects = getIntersectionXYZs(nearerXYZ, fartherXYZ, centeredAt, + largestRLessNearerPoint) + + if len(currentIntersects) == 2: + if nearerDist > largestRLessNearerPoint: + radiiCrossed.append(largestRLessNearerPoint) + if fartherDist > largestRLessNearerPoint: + radiiCrossed.append(largestRLessNearerPoint) + for rad in radiiCrossed: + intersects[radii.index(rad)] += 1 + + return intersects + + def getLengthVsDistance(self, radii, centeredAt=None): + """ + Calculates and returns the length of dendrites of the morphology contained within concentric shells defined by + adjacent values of input argument radii. First shell is the sphere of radius radii[0] + :param radii: iterable of positive floats of size at least 2, radii "bin edges" of shells + :param centeredAt: iterable of floats of size 3, containing the X, Y and Z coordinates of the center of the + spheres + :return: list of size len(radii), of lengths, corresponding to concentric shells defined by adjacent values + of radii. + """ + + assert len(radii) >= 2, "Input argument radii must have at least 2 numbers, got {}".format(radii) + assert all(x > 0 for x in radii), "Input argument radii can only consist of postive numbers, " \ + "got {}".format(radii) + radii = list(radii) + lengths = [0 for x in radii] + assert radii == sorted(radii), "Input argument radii must be sorted" + radiiSorted = np.array(radii) + + if centeredAt is None: + centeredAt = self.get_tree().root.content["p3d"].xyz + + assert len(centeredAt) == 3, "Input argument centeredAt must be a 3 member iterable of numbers, " \ + "got {}".format(centeredAt) + centeredAt = np.asarray(centeredAt) + + def nodeDistance(n): + nXYZ = np.asarray(n.content["p3d"].xyz) + return np.linalg.norm(nXYZ - centeredAt) + + def nodeXYZ(n): + return np.asarray(n.content["p3d"].xyz) + + allNodesExceptRoot = [x for x in self.get_tree().breadth_first_iterator_generator() if x.parent] + + for node in allNodesExceptRoot: + + nodeDist = nodeDistance(node) + parentDist = nodeDistance(node.parent) + + # both points are within first shell + if nodeDist <= radiiSorted[0] and parentDist <= radiiSorted[0]: + lengths[0] += np.linalg.norm(nodeXYZ(node) - nodeXYZ(node.parent)) + + else: + if nodeDist < parentDist: + nearerPoint = nodeXYZ(node) + nearerDist = nodeDist + fartherPoint = nodeXYZ(node.parent) + fartherDist = parentDist + else: + nearerPoint = nodeXYZ(node.parent) + nearerDist = parentDist + fartherPoint = nodeXYZ(node) + fartherDist = nodeDist + + radiiCrossedMask = np.logical_and(nearerDist < radiiSorted, radiiSorted <= fartherDist) + radiiCrossed = radiiSorted[radiiCrossedMask] + + # line connecting the points are within one shell + if len(radiiCrossed) == 0: + + largestRLessNearerPoint = radiiSorted[radiiSorted <= nearerDist].max() + intersects = getIntersectionXYZs(nearerPoint, fartherPoint, centeredAt, largestRLessNearerPoint) + + # line joining the points does not intersect any sphere + if len(intersects) == 0: + shellIndex = radii.index(largestRLessNearerPoint) + 1 + lengths[shellIndex] += np.linalg.norm(fartherPoint - nearerPoint) + # line joining the points intersects a sphere are two distinct points + elif len(intersects) == 2: + intersects = np.array(intersects) + innerShellIndex = radii.index(largestRLessNearerPoint) + outerShellIndex = radii.index(largestRLessNearerPoint) + 1 + lengths[outerShellIndex] += np.linalg.norm(intersects[0] - nearerPoint) + lengths[outerShellIndex] += np.linalg.norm(fartherPoint - intersects[1]) + lengths[innerShellIndex] += np.linalg.norm(intersects[1] - intersects[0]) + + else: + raise(ValueError("Impossible case! There has been a wrong assumption.")) + + # line connecting the points is contained in at least two shells + else: + tempNearestPoint = nearerPoint + for rad in radiiCrossed: + shellIndex = radii.index(rad) + intersects = getIntersectionXYZs(nearerPoint, fartherPoint, centeredAt, rad) + assert len(intersects) == 1, "Impossible case! There has been a wrong assumption." + intersect = np.array(intersects[0]) + lengths[shellIndex] += np.linalg.norm(intersect - tempNearestPoint) + tempNearestPoint = intersect + if shellIndex + 1 < len(radii): + lengths[shellIndex + 1] += np.linalg.norm(fartherPoint - tempNearestPoint) + + return lengths + + + + + + diff --git a/tests/auxFuncs_test.py b/tests/auxFuncs_test.py new file mode 100644 index 0000000..e6cb2df --- /dev/null +++ b/tests/auxFuncs_test.py @@ -0,0 +1,61 @@ +from btmorph2.auxFuncs import getIntersectionXYZs +import numpy as np + + +def getPointAtDistance_test(): + """ + Testing getIntersectionXYZs function + """ + + # case1, two equal real intersections + parentXYZ = [20, 0, 0] + childXYZ = [20, -10, 0] + centeredAt = [0, 0, 0] + radius = 20 + expectedIntersects = [[20, 0, 0], [20, 0, 0]] + intersects = getIntersectionXYZs(parentXYZ, childXYZ, centeredAt, radius) + assert np.shape(expectedIntersects) == np.shape(intersects) and np.allclose(expectedIntersects, intersects, + atol=1e-2) + + # case2, one real intersections + parentXYZ = [10, 0, 0] + childXYZ = [20, -10, 0] + centeredAt = [0, 0, 0] + radius = 15 + expectedIntersects = [[14.35, -4.35, 0.0]] + intersects = getIntersectionXYZs(parentXYZ, childXYZ, centeredAt, radius) + assert np.shape(expectedIntersects) == np.shape(intersects) and np.allclose(expectedIntersects, intersects, + atol=1e-2) + + # case 3, one real intersections + parentXYZ = [20, 0, 0] + childXYZ = [30, 10, 0] + centeredAt = [0, 0, 0] + radius = 25 + expectedIntersects = [[24.58, 4.58, 0.0]] + intersects = getIntersectionXYZs(parentXYZ, childXYZ, centeredAt, radius) + assert np.shape(expectedIntersects) == np.shape(intersects) and np.allclose(expectedIntersects, intersects, + atol=1e-2) + + # case 4, no intersections + parentXYZ = [50, -10, 0] + childXYZ = [50, -20, 0] + centeredAt = [0, 0, 0] + radius = 50 + expectedIntersects = [] + intersects = getIntersectionXYZs(parentXYZ, childXYZ, centeredAt, radius) + assert np.shape(expectedIntersects) == np.shape(intersects) and np.allclose(expectedIntersects, intersects, + atol=1e-2) + + # case 5, two unequal real intersections + parentXYZ = [2.5, 7.5, 0] + childXYZ = [2.5, -7.5, 0] + centeredAt = [0, 0, 0] + radius = 5 + expectedIntersects = [[2.5, 4.33, 0.0], [2.5, -4.33, 0.0]] + intersects = getIntersectionXYZs(parentXYZ, childXYZ, centeredAt, radius) + assert np.shape(expectedIntersects) == np.shape(intersects) and np.allclose(expectedIntersects, intersects, + atol=1e-2) + +if __name__ == "__main__": + intersects = getIntersectionXYZs(p1=[20, 30, 0], p2=[30, 40, 0], radius=50, centeredAt=[0, 0, 0]) \ No newline at end of file diff --git a/tests/horton-strahler_test_wiki.png b/tests/horton-strahler_test_wiki.png new file mode 100644 index 0000000..9d048bb Binary files /dev/null and b/tests/horton-strahler_test_wiki.png differ diff --git a/tests/sholl_test.png b/tests/sholl_test.png new file mode 100644 index 0000000..a9e0c92 Binary files /dev/null and b/tests/sholl_test.png differ diff --git a/tests/sholl_test.swc b/tests/sholl_test.swc new file mode 100644 index 0000000..6e6b32d --- /dev/null +++ b/tests/sholl_test.swc @@ -0,0 +1,26 @@ +# Horton-Strahler index example from wikipedia +#http://en.wikipedia.org/wiki/Strahler_number +#http://en.wikipedia.org/wiki/File:Flussordnung_%28Strahler%29.svg +#For testing purposes +1 1 -1.0 -1.0 -1.0 1.9 -1 +2 3 0.0 10 0.0 1 1 +3 3 -10.0 20.0 0.0 1 2 +4 3 0.0 20.0 0.0 1 2 +5 3 10.0 30.0 0.0 1 4 +6 3 20.0 30.0 0.0 1 5 +7 3 30.0 30.0 0.0 1 6 +8 3 30.0 40.0 0.0 1 6 +9 3 10.0 40.0 0.0 1 5 +10 3 20.0 40.0 0.0 1 9 +11 3 10.0 50.0 0.0 1 9 +12 3 0.0 30.0 0.0 1 4 +13 3 -10.0 40.0 0.0 1 12 +14 3 0.0 40.0 0.0 1 12 +15 3 -10.0 50.0 0.0 1 14 +16 3 -20.0 50.0 0.0 1 15 +17 3 -10.0 60.0 0.0 1 15 +18 3 0.0 50.0 0.0 1 14 +19 3 0.0 60.0 0.0 1 18 +20 3 10.0 60.0 0.0 1 18 +21 3 12.5 -2.5 0.0 1 2 +22 3 -17.5 42.5 0.0 1 16 \ No newline at end of file diff --git a/tests/structs_test.py b/tests/structs_test.py index 0af3e59..df18e8c 100644 --- a/tests/structs_test.py +++ b/tests/structs_test.py @@ -492,3 +492,44 @@ def getGlobalScalarMeasures_NM_test(): } assert all([np.allclose(tests_globalFeatures[x], expectedGlobalFeatures[x], atol=1e-3) for x in expectedGlobalFeatures.keys()]) + + +def sholl_test(): + """ + Testing the getIntersectionsVsDistance function of Neuron Morphology + :return: + """ + + swcFile = "tests/sholl_test.swc" + nrn = NeuronMorphology(swcFile, correctIfSomaAbsent=True) + radii = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60] + centeredAt = [0, 0, 0] + intersects = nrn.getIntersectionsVsDistance(radii=radii, centeredAt=centeredAt) + + expectedIntersects = [1, 2, 2, 2, 2, 2, 4, 5, 4, 5, 3, 3] + + assert np.allclose(intersects, expectedIntersects, atol=0) + + +def lengthVsDistance_test(): + """ + Testing the getLengthVsDistance function of NeuronMorphology + :return: + """ + + swcFile = "tests/sholl_test.swc" + nrn = NeuronMorphology(swcFile, correctIfSomaAbsent=True) + radii = range(5, 65, 5) + centeredAt = [0, 0, 0] + lengths = nrn.getLengthVsDistance(radii=radii, centeredAt=centeredAt) + expectedLengths = [9.8110707864739055, 19.221601343666464, 14.6873629022557, + 10.487148622007609, 13.980256121069152, 10.840702012600886, + 25.026997102991103, 28.703342699720178, 36.90289680818843, + 25.45313963275283, 31.924143237189671, 16.429818895055199] + + assert np.allclose(lengths, expectedLengths) + + + +# if __name__ == "__main__": +# lengthVsDistance_test()