-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathACRStateMPPAEFTTest.py
executable file
·103 lines (87 loc) · 4.37 KB
/
ACRStateMPPAEFTTest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import unittest
from collections import Counter
import pandas as pd
import numpy as np
from pastml.models.EFTModel import EFT
from pastml.tree import read_tree, collapse_zero_branches
from pastml.acr import acr
from pastml.ml import MPPA
# Input files live in the ``data`` directory located next to this module.
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
TREE_NWK = os.path.join(DATA_DIR, 'Albanian.tree.152tax.tre')
STATES_INPUT = os.path.join(DATA_DIR, 'data.txt')

# Name of the annotation column; the tests read the node attribute of the same name.
feature = 'Country'

# Annotation table restricted to the single column under study
# (first file column is the index, first row is the header).
df = pd.read_csv(STATES_INPUT, header=0, index_col=0)[[feature]]

# Shared tree: the reconstruction is run once at import time and its return value is ignored;
# the tests below inspect the states through the ``feature`` attribute of the tree nodes.
tree = read_tree(TREE_NWK)
acr(tree, df, model=EFT, prediction_method=MPPA)
class ACRStateMPPAEFTTest(unittest.TestCase):
    """Checks the node states produced by ``acr`` with the MPPA method and the EFT model.

    Except for ``test_collapsed_vs_full`` (which also reconstructs a freshly read tree),
    the tests inspect the module-level ``tree`` on which ``acr`` was run at import time.
    """

    def _assert_node_state(self, nodes, name, expected_state):
        """Assert that the node called ``name`` carries exactly ``expected_state``.

        The test fails explicitly if no node with this name is found: otherwise
        a renamed or missing node would let the test pass without checking anything.

        :param nodes: iterable of tree nodes to search through.
        :param name: name of the node to check.
        :param expected_state: set of states expected in the node's ``feature`` attribute.
        """
        for node in nodes:
            if name == node.name:
                state = getattr(node, feature)
                self.assertSetEqual(expected_state, state, msg='{} state was supposed to be {}, got {}.'
                                    .format(node.name, expected_state, state))
                return
        self.fail('Node {} was not found in the tree.'.format(name))

    def test_collapsed_vs_full(self):
        """States of a freshly read and reconstructed tree must match those of the shared tree."""
        # NOTE(review): the shared module-level tree plays the role of the "collapsed" tree here,
        # although no collapsing is performed explicitly in this module - confirm.
        tree_uncollapsed = read_tree(TREE_NWK)
        acr(tree_uncollapsed, df, prediction_method=MPPA, model=EFT)

        def get_state(node):
            # Canonical string form of a state set, so that states can be compared as DataFrame cells.
            return ', '.join(sorted(getattr(node, feature)))

        df_full = pd.DataFrame.from_dict({node.name: get_state(node) for node in tree_uncollapsed.traverse()},
                                         orient='index', columns=['full'])
        df_collapsed = pd.DataFrame.from_dict({node.name: get_state(node) for node in tree.traverse()},
                                              orient='index', columns=['collapsed'])
        # Left join on node names: every node of the shared tree must find an equal state in the full one
        # (a name without a counterpart yields a missing value, which compares as unequal).
        df_joint = df_collapsed.join(df_full, how='left')
        self.assertTrue(np.all((df_joint['collapsed'] == df_joint['full'])),
                        msg='All the node states of the collapsed tree should be the same as of the full one.')

    def test_num_nodes(self):
        """Counts nodes per state on a copy of the shared tree after ``collapse_zero_branches``."""
        state2num = Counter()
        # Work on a copy so that the shared tree used by the other tests is left untouched.
        root = tree.copy()
        collapse_zero_branches([root])
        for node in root.traverse():
            state = getattr(node, feature)
            if len(state) > 1:
                # Several states retained for this node: counted together as unresolved.
                state2num['unresolved'] += 1
            else:
                state2num[next(iter(state))] += 1
        expected_state2num = {'unresolved': 8, 'Africa': 109, 'Albania': 50, 'Greece': 65, 'WestEurope': 29,
                              'EastEurope': 16}
        self.assertDictEqual(expected_state2num, state2num, msg='Was supposed to have {} as states counts, got {}.'
                             .format(expected_state2num, state2num))

    def test_state_root(self):
        """The root of the shared tree must be resolved to Africa."""
        expected_state = {'Africa'}
        state = getattr(tree, feature)
        self.assertSetEqual(expected_state, state,
                            msg='Root state was supposed to be {}, got {}.'.format(expected_state, state))

    def test_state_unresolved_internal_node(self):
        """Node ``node_79`` must keep two possible states."""
        self._assert_node_state(tree.traverse(), 'node_79', {'Africa', 'Greece'})

    def test_state_node_32(self):
        """Node ``node_32`` must be resolved to WestEurope."""
        self._assert_node_state(tree.traverse(), 'node_32', {'WestEurope'})

    def test_state_resolved_internal_node(self):
        """Node ``node_80`` must be resolved to Greece."""
        self._assert_node_state(tree.traverse(), 'node_80', {'Greece'})

    def test_state_zero_tip(self):
        """Node ``01ALAY1715`` must have the state Albania."""
        self._assert_node_state(tree.traverse(), '01ALAY1715', {'Albania'})

    def test_state_tip(self):
        """Node ``94SEAF9671`` must have the state WestEurope."""
        # The tree is iterated directly (not via traverse()), as in the original test;
        # presumably this visits the tips only - confirm against the tree class.
        self._assert_node_state(tree, '94SEAF9671', {'WestEurope'})