-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathACRStateMPPAJCTest.py
executable file
·122 lines (104 loc) · 5.19 KB
/
ACRStateMPPAJCTest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import unittest
from collections import Counter
import pandas as pd
import numpy as np
from pastml.models.JCModel import JC
from pastml.tree import read_tree, collapse_zero_branches
from pastml.acr import acr
from pastml.ml import MPPA
# Fixture locations, resolved relative to this file so the tests do not depend
# on the current working directory.
DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
TREE_NWK, STATES_INPUT = (os.path.join(DATA_DIR, file_name)
                          for file_name in ('Albanian.tree.152tax.tre', 'data.txt'))

# The single character being reconstructed; only this column of the
# annotation table is passed to the reconstruction.
feature = 'Country'
df = pd.read_csv(STATES_INPUT, header=0, index_col=0)[[feature]]

# Shared fixture built once at import time: the tree is annotated by an MPPA
# reconstruction under the JC model, and the test cases read the resulting
# per-node `feature` attribute.
tree = read_tree(TREE_NWK)
acr(tree, df, prediction_method=MPPA, model=JC)
class ACRStateMPPAJCTest(unittest.TestCase):
    """Checks the MPPA/JC reconstruction of ``feature`` on the tree loaded from ``TREE_NWK``.

    Apart from ``test_collapsed_vs_full``, which builds its own reconstruction,
    the tests inspect the module-level ``tree`` annotated at import time.
    """

    def _assert_node_state(self, nodes, name, expected_state):
        """Assert that the node called ``name`` among ``nodes`` carries ``expected_state``.

        :param nodes: iterable of tree nodes to search, e.g. ``tree.traverse()``
            (the tip test iterates ``tree`` itself, presumably its leaves).
        :param name: name of the node to check.
        :param expected_state: set of states the node's ``feature`` attribute
            must equal.

        Fails the test if no node with that name exists, so that a missing or
        renamed node cannot make the check pass vacuously.
        """
        node = next((candidate for candidate in nodes if candidate.name == name), None)
        if node is None:
            self.fail('Node {} was not found in the tree.'.format(name))
        state = getattr(node, feature)
        self.assertSetEqual(expected_state, state, msg='{} state was supposed to be {}, got {}.'
                            .format(node.name, expected_state, state))

    def test_collapsed_vs_full(self):
        """A fresh reconstruction on a newly read tree must agree node-by-node with the shared one."""
        tree_uncollapsed = read_tree(TREE_NWK)
        acr(tree_uncollapsed, df, prediction_method=MPPA, model=JC)

        def get_state(node):
            # Canonical string form of a state set, so sets compare as plain column values.
            return ', '.join(sorted(getattr(node, feature)))

        df_full = pd.DataFrame.from_dict({node.name: get_state(node) for node in tree_uncollapsed.traverse()},
                                         orient='index', columns=['full'])
        df_collapsed = pd.DataFrame.from_dict({node.name: get_state(node) for node in tree.traverse()},
                                              orient='index', columns=['collapsed'])
        # Left join on node name: a node missing from the full tree yields NaN and fails the equality.
        df_joint = df_collapsed.join(df_full, how='left')
        self.assertTrue(np.all((df_joint['collapsed'] == df_joint['full'])),
                        msg='All the node states of the collapsed tree should be the same as of the full one.')

    def test_num_nodes(self):
        """Count nodes per predicted state after collapsing zero-length branches on a copy of the tree."""
        state2num = Counter()
        root = tree.copy()
        collapse_zero_branches([root])
        for node in root.traverse():
            state = getattr(node, feature)
            # More than one predicted state means the node was left unresolved.
            if len(state) > 1:
                state2num['unresolved'] += 1
            else:
                state2num[next(iter(state))] += 1
        expected_state2num = {'unresolved': 9, 'Africa': 110, 'Albania': 50, 'Greece': 65, 'WestEurope': 27,
                              'EastEurope': 16}
        self.assertDictEqual(expected_state2num, state2num, msg='Was supposed to have {} as states counts, got {}.'
                             .format(expected_state2num, state2num))

    def test_state_root(self):
        """The root must be resolved to Africa."""
        expected_state = {'Africa'}
        state = getattr(tree, feature)
        self.assertSetEqual(expected_state, state,
                            msg='Root state was supposed to be {}, got {}.'.format(expected_state, state))

    def test_state_node_79(self):
        self._assert_node_state(tree.traverse(), 'node_79', {'Africa', 'Greece'})

    def test_state_node_32(self):
        self._assert_node_state(tree.traverse(), 'node_32', {'WestEurope', 'Greece'})

    def test_state_node_80(self):
        self._assert_node_state(tree.traverse(), 'node_80', {'Greece'})

    def test_state_node_25(self):
        self._assert_node_state(tree.traverse(), 'node_25', {'Greece', 'WestEurope'})

    def test_state_node_48(self):
        self._assert_node_state(tree.traverse(), 'node_48', {'Africa', 'Greece', 'WestEurope'})

    def test_state_zero_tip(self):
        self._assert_node_state(tree.traverse(), '01ALAY1715', {'Albania'})

    def test_state_tip(self):
        # Iterates the tree object itself, as the original test did, rather than traverse().
        self._assert_node_state(tree, '94SEAF9671', {'WestEurope'})