Skip to content

Commit

Permalink
updated formatting issues and docs in birthdeathfitnessimulator
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjones315 committed May 8, 2024
1 parent 77fc89b commit 283cdc7
Showing 1 changed file with 57 additions and 37 deletions.
94 changes: 57 additions & 37 deletions cassiopeia/simulator/BirthDeathFitnessSimulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
process, including differing fitness on lineages on the tree. Allows for a
variety of division and fitness regimes to be specified by the user.
"""

from typing import Callable, Dict, Generator, List, Optional, Union

import networkx as nx
Expand Down Expand Up @@ -153,17 +154,28 @@ def __init__(
self.initial_tree = initial_tree

def initialize_tree(self, names) -> nx.DiGraph:
"""initializes a tree (nx.DiGraph() object with one node). Auxiliary data for each node is grabbed from self (initial conditions / params) or hardcoded
"""Initializes a tree.
Initializes a tree (nx.DiGraph() object with one node). Auxiliary data
for each node is grabbed from self (initial conditions / params) or
hardcoded.
Args: names is a generator (function object that stores internal state) that will be used to generate names for the tree nodes
Output: tree (DiGraph object with one node, the root)
root (name of root node in tree)
Args:
names: A generator (function object that stores internal state) that
will be used to generate names for the tree nodes
Returns:
tree (DiGraph object with one node, the root) and root
(name of root node in tree)
"""
if self.initial_tree:
tree = self.initial_tree.get_tree_topology()
for node in self.initial_tree.nodes:
tree.nodes[node]['birth_scale'] = self.initial_tree.get_attribute(node, 'birth_scale')
tree.nodes[node]['time'] = self.initial_tree.get_attribute(node, 'time')
tree.nodes[node]["birth_scale"] = (
self.initial_tree.get_attribute(node, "birth_scale")
)
tree.nodes[node]["time"] = self.initial_tree.get_attribute(
node, "time"
)
return tree

tree = nx.DiGraph()
Expand All @@ -175,32 +187,34 @@ def initialize_tree(self, names) -> nx.DiGraph:
return tree

def make_initial_lineage_dict(self, tree: nx.DiGraph):
"""
uses self initial-conditions and hardcoded default parameters to create an initial lineage dict
Args: id_value: name of new lineage
"""Makes initial lineage queue.
Uses self initial-conditions and hardcoded default parameters to create
an initial lineage dict
Output: a lineage dict
Args:
id_value: name of new lineage
Returns:
A lineage dict
"""

leaves = [node for node in tree if tree.out_degree(node) == 0]
current_lineages = PriorityQueue()
for leaf in leaves:

lineage_dict = self.make_lineage_dict(
leaf, tree.nodes[leaf]['birth_scale'], tree.nodes[leaf]['time'], True
)

leaf,
tree.nodes[leaf]["birth_scale"],
tree.nodes[leaf]["time"],
True,
)

if len(tree.nodes) == 1:
return lineage_dict

current_lineages.put(
(
tree.nodes[leaf]['time'],
leaf,
lineage_dict
)
)


current_lineages.put((tree.nodes[leaf]["time"], leaf, lineage_dict))

return current_lineages

def make_lineage_dict(
Expand All @@ -212,12 +226,14 @@ def make_lineage_dict(
):
"""makes a dict (lineage) from the given parameters. keys are hardcoded.
Args: id_value: id of new lineage
birth_scale: birth_scale parameter of new lineage
total_time: age of lineage
active_flag: bool to indicate whether lineage is active
Args:
id_value: id of new lineage
birth_scale: birth_scale parameter of new lineage
total_time: age of lineage
active_flag: bool to indicate whether lineage is active
Returns: a dict (lineage) with the parameter values under the hard-coded keys
Returns:
A dict (lineage) with the parameter values under the hard-coded keys
"""
lineage_dict = {
Expand Down Expand Up @@ -250,7 +266,7 @@ def simulate_tree(
TreeSimulatorError if all lineages die before a stopping condition
"""

def node_name_generator(start = 0) -> Generator[str, None, None]:
def node_name_generator(start=0) -> Generator[str, None, None]:
"""Generates unique node names for the tree."""
i = start
while True:
Expand All @@ -259,23 +275,25 @@ def node_name_generator(start = 0) -> Generator[str, None, None]:

starting_index = 0
if self.initial_tree:
starting_index = (np.max([int(l) for l in self.initial_tree.leaves]) + 1)
starting_index = (
np.max([int(l) for l in self.initial_tree.leaves]) + 1
)
names = node_name_generator(starting_index)

# Set the seed
if self.random_seed:
np.random.seed(self.random_seed)

tree = self.initialize_tree(names)

current_lineages = PriorityQueue() # instantiate queue
current_lineages = PriorityQueue() # instantiate queue
# Records the nodes that are observed at the end of the experiment

# TO DO: update to accept arbitrary fields in the dict.
observed_nodes = []

starting_lineage = self.make_initial_lineage_dict(tree)

if len(tree.nodes) == 1:
# Sample the waiting time until the first division
self.sample_lineage_event(
Expand Down Expand Up @@ -444,7 +462,6 @@ def sample_lineage_event(
)
)


else:
tree.add_node(unique_id)
tree.nodes[unique_id]["birth_scale"] = lineage["birth_scale"]
Expand Down Expand Up @@ -509,7 +526,8 @@ def populate_tree_from_simulation(
"""Populates tree with appropriate meta data.
Args:
tree: The tree simulated with ecDNA and fitness values populated as attributes.
tree: The tree simulated with ecDNA and fitness values populated as
attributes.
observed_nodes: The observed leaves of the tree.
Returns:
Expand All @@ -521,7 +539,9 @@ def populate_tree_from_simulation(
time_dictionary = {}
for i in tree.nodes:
time_dictionary[i] = tree.nodes[i]["time"]
cassiopeia_tree.set_attribute(i, 'birth_scale', tree.nodes[i]['birth_scale'])
cassiopeia_tree.set_attribute(
i, "birth_scale", tree.nodes[i]["birth_scale"]
)
cassiopeia_tree.set_times(time_dictionary)

# Prune dead lineages and collapse resulting unifurcations
Expand Down

0 comments on commit 283cdc7

Please sign in to comment.