diff --git a/src/lineagetree/_basics/_navigation.py b/src/lineagetree/_basics/_navigation.py index 9a86ea2..7c23a94 100644 --- a/src/lineagetree/_basics/_navigation.py +++ b/src/lineagetree/_basics/_navigation.py @@ -360,23 +360,20 @@ def nodes_at_t( list of int list of ids of the nodes at time `t` spawned by `r` """ - if not r and r != 0: - r = {root for root in lT.roots if lT.time[root] <= t} + if r is None: + return lT.time_nodes.get(t, []) if isinstance(r, int): r = [r] if t is None: t = lT.t_e to_do = list(r) final_nodes = [] - while len(to_do) > 0: + while 0 < len(to_do): curr = to_do.pop() - for _next in lT._successor[curr]: - if lT._time[_next] < t: - to_do.append(_next) - elif lT._time[_next] == t: - final_nodes.append(_next) - if not final_nodes: - return list(r) + if lT._time[curr] == t: + final_nodes.append(curr) + elif lT._time[curr] < t: + to_do.extend(lT.successor[curr]) return final_nodes