Skip to content

Commit

Permalink
ADD: ARG.from_path and to_ts
Browse files Browse the repository at this point in the history
  • Loading branch information
tcztzy committed Mar 2, 2024
1 parent 238d6ca commit f710f35
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 1 deletion.
145 changes: 144 additions & 1 deletion argweavers/arg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,148 @@
import re
from pathlib import Path

import networkx as nx # type: ignore
import numpy as np # type: ignore
import pandas as pd # type: ignore
import tskit # type: ignore

# Public names exported by a star-import of this module.
__all__ = ["ARG"]


class ARG:
    """Ancestral recombination graph read from an ARGweaver ``.arg`` file.

    The graph is held in ``G``, a networkx ``DiGraph`` whose edges point from
    a child node to its parent. Each edge carries a ``parent`` attribute with
    the position of that parent in the child's ``parents`` column. The
    sequence length is stored as a graph-level attribute in ``G.graph``.
    """

    # Quoted so the annotation is not evaluated when the class is created.
    G: "nx.DiGraph"

    @property
    def sequence_length(self) -> int:
        """Return the sequence length taken from ``end=`` in the file header."""
        # Graph-level data lives in ``G.graph``. ``G[key]`` is a node adjacency
        # lookup and does not support item assignment.
        return self.G.graph["sequence_length"]

    @staticmethod
    def _parse_header(line: str) -> int:
        """Return the sequence length encoded in an ARGweaver header line.

        Raises:
            ValueError: if the line is not of the form ``start=<int>``,
                a tab, ``end=<int>``, or if ``start`` is not 0.
        """
        mo = re.match(r"^start=(\d+)\tend=(\d+)$", line.strip())
        if mo is None:
            raise ValueError(f"Malformed ARGweaver header line: {line!r}")
        start, end = int(mo.group(1)), int(mo.group(2))
        if start != 0:
            raise ValueError(
                f"ARGweaver file must start at position 0, got start={start}"
            )
        return end

    @classmethod
    def from_path(cls, path):
        """Construct an ARG from an ARGweaver file.

        The `.arg` file starts with a header line giving the start and end
        positions of the sequence. The rest of the file is a tab-separated
        table with the following columns:

        - name: the name of the node
        - event: the type of event (gene, recomb, coal)
        - age: the age of the node
        - pos: the position of the recombination event (if event is recomb)
        - parents: the parents of the node
        - children: the children of the node

        Gene (sample) nodes are renumbered by their rank among the sorted
        gene names; every other node uses its name converted to ``int``.

        Raises:
            ValueError: if the header is malformed, a required column is
                missing, or a non-gene node name / parent name is not an
                integer.
        """
        path = Path(path)
        with open(path) as f:
            # ``next(f, "")`` turns an empty file into a header error instead
            # of a bare StopIteration.
            sequence_length = cls._parse_header(next(f, ""))
            # The "name" field can be a string. Force it to be so, in case it
            # is just numbers.
            df = pd.read_csv(
                f, header=0, sep="\t", dtype={"name": str, "parents": str}, index_col=0
            )
        for col in ("event", "parents", "age"):
            if col not in df.columns:
                raise ValueError(f"Column {col} not found in ARGweaver file")

        names = sorted(df.index[df["event"] == "gene"])
        # Rank of each gene name, looked up in O(1). ``setdefault`` keeps the
        # first occurrence, matching ``list.index``.
        sample_ids = {}
        for rank, name in enumerate(names):
            sample_ids.setdefault(name, rank)

        # Make an nx DiGraph so we can do a topological sort.
        # NOTE(review): gene ranks and integer names of internal nodes share
        # one id space; this presumably relies on them never colliding --
        # confirm against the ARGweaver naming scheme.
        G = nx.DiGraph()
        G.graph["sequence_length"] = sequence_length
        for name, row in df.iterrows():
            event = row["event"]
            if event == "gene":
                # Look the rank up by the raw name; only the stored sample
                # label is converted to int when it is purely numeric.
                try:
                    sample = int(name)
                except ValueError:
                    sample = name
                node = sample_ids[name]
                G.add_node(node, age=row["age"], event=event, sample=sample)
            elif event == "recomb":
                node = int(name)
                G.add_node(node, age=row["age"], event=event, pos=row["pos"])
            else:
                node = int(name)
                G.add_node(node, age=row["age"], event=event)

            parents = row["parents"]
            if isinstance(parents, str):
                parent_ids = [int(p) for p in parents.split(",")]
                for i, parent in enumerate(parent_ids):
                    G.add_edge(node, parent, parent=i)
            else:
                # An empty "parents" cell is read as NaN: this node is a root.
                G.add_node(node, root=True)

        self = cls()
        self.G = G
        return self

    def to_ts(self) -> "tskit.TreeSequence":
        """Convert the graph into a tskit tree sequence.

        Nodes are added in topological order (children before parents). Each
        node's graph attributes, plus a ``name`` entry, are stored as JSON
        metadata. A node with two parents is a recombination node: the first
        parent is inherited left of the breakpoint and the second parent
        right of it.

        Returns:
            The tree sequence after ``simplify(keep_unary=True)``.

        Raises:
            networkx.exception.NetworkXUnfeasible: if the graph has a cycle.
            ValueError: if a gene node has a non-zero age or a node has more
                than two parents.
        """
        tables = tskit.TableCollection(sequence_length=self.sequence_length)
        tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json()
        breakpoints = np.full(len(self.G), tables.sequence_length)
        aw_to_tsk_id = {}
        times = np.unique([a for _, a in self.G.nodes.data("age")])
        # Number of epsilons to add to the next non-sample node of each age.
        # Age 0 starts at 1 so such nodes end up strictly above the samples.
        time_map = {time: 1 if time == 0 else 0 for time in times}
        # With fewer than two distinct ages there is no gap to scale against.
        min_time_diff = np.diff(times).min() if len(times) > 1 else 1.0
        epsilon = min_time_diff / 1e6
        try:
            for node in nx.lexicographical_topological_sort(self.G):
                record = self.G.nodes[node].copy()
                record["name"] = record.get("sample", str(node))
                age = record["age"]
                flags = 0
                # Sample nodes are marked as "gene" events
                if record["event"] == "gene":
                    flags = tskit.NODE_IS_SAMPLE
                    if age != 0:
                        raise ValueError(
                            f"Gene node {node} has non-zero age {age}"
                        )
                    time = age
                else:
                    time = age + time_map[age] * epsilon
                    # Argweaver allows age of parent and child to be the same,
                    # so we need to add epsilons to enforce
                    # parent_age > child_age
                    time_map[age] += 1
                tsk_id = tables.nodes.add_row(flags=flags, time=time, metadata=record)
                aw_to_tsk_id[node] = tsk_id
                if record["event"] == "recomb":
                    breakpoints[tsk_id] = record["pos"]
        except nx.exception.NetworkXUnfeasible as err:
            bad_edges = nx.find_cycle(self.G, orientation="original")
            raise nx.exception.NetworkXUnfeasible(
                f"Cycle found in ARGweaver graph: {bad_edges}"
            ) from err

        L = tables.sequence_length
        for aw_node in self.G:
            child = aw_to_tsk_id[aw_node]
            # Out-edges go child -> parent; order them by parent position.
            parents = [
                aw_to_tsk_id[aw_parent]
                for _, aw_parent, _ in sorted(
                    self.G.edges(aw_node, data="parent"), key=lambda x: x[2]
                )
            ]
            if len(parents) == 1:
                tables.edges.add_row(left=0, right=L, parent=parents[0], child=child)
            elif len(parents) == 2:
                # Recombination node.
                # Note that this uses the 1-RE-node convention
                x = breakpoints[child]
                tables.edges.add_row(left=0, right=x, parent=parents[0], child=child)
                tables.edges.add_row(left=x, right=L, parent=parents[1], child=child)
            elif parents:
                raise ValueError(
                    f"Node {aw_node} has {len(parents)} parents; expected at most 2"
                )
        tables.sort()
        ts = tables.tree_sequence()
        return ts.simplify(keep_unary=True)


if __name__ == "__main__":
    import argparse

    # Minimal command line: read one ARGweaver file and write the converted
    # tree sequence next to it with a ".trees" suffix.
    cli = argparse.ArgumentParser()
    cli.add_argument("path", type=Path)
    source = cli.parse_args().path
    tree_sequence = ARG.from_path(source).to_ts()
    tree_sequence.dump(source.with_suffix(".trees"))
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"matplotlib>=3.7.5",
"biopython>=1.83",
"networkx>=3.1",
"tskit>=0.5.6",
]

[project.optional-dependencies]
Expand Down

0 comments on commit f710f35

Please sign in to comment.