# Source: solution.py, scraped from a GitHub file view
# (147 lines, 122 loc, 4.83 KB; the page showed "Fork 110").
# The page's UI text ("Notifications", "You must be signed in to change
# notification settings") and its line-number gutter (1-147) were dropped.
#!/bin/python
import math
import os
import random
import re
import sys
from collections import deque
class Node:
    """One vertex of the rooted tree produced by constructTree.

    Attributes:
        nid: zero-based index of the vertex.
        val: the vertex's own value.
        cost: initialised to ``val``; constructTree later adds every
            child's cost to it, so it ends up as the subtree sum.
        children: list of child ``Node`` objects (fresh list per node).
    """

    def __init__(self, value, nid):
        self.nid = nid
        self.val = self.cost = value
        self.children = []
def constructTree(v, e):
    """Build a tree rooted at node index 0 and fill in subtree costs.

    Args:
        v: node values; ``v[i]`` belongs to zero-based node ``i``.
        e: iterable of 1-based ``(src, des)`` pairs, each treated as an
           undirected edge.

    Returns:
        The root ``Node``. Each node reachable from the root gets a ``Node``
        whose ``cost`` is its own value plus the costs of its children,
        i.e. its subtree sum.
    """
    count = len(v)
    neighbours = {idx: [] for idx in range(count)}
    for src, des in e:
        # Input ids are 1-based; everything internal is 0-based.
        a, b = src - 1, des - 1
        neighbours[a].append(b)
        neighbours[b].append(a)

    nodes = [None] * count
    nodes[0] = Node(v[0], 0)

    # Phase 1: iterative DFS from the root. A node receives its Node object
    # (and becomes a child of the node being expanded) the first time any
    # expanded neighbour sees it; ``order`` records the expansion order.
    order = []
    stack = [0]
    while stack:
        idx = stack.pop()
        order.append(idx)
        parent = nodes[idx]
        for nb in neighbours[idx]:
            if nodes[nb] is None:
                child = Node(v[nb], nb)
                nodes[nb] = child
                parent.children.append(child)
                stack.append(nb)

    # Phase 2: children are always expanded after their parent, so walking
    # the expansion order backwards finalises every child's cost before it
    # is added into its parent. Leaves keep their own value.
    for idx in reversed(order):
        node = nodes[idx]
        for child in node.children:
            node.cost += child.cost

    return nodes[0]
# Sentinel meaning "no valid split found": every real candidate must be
# smaller, and aux_traversal maps a result that is still this value to -1.
# The bound presumably comes from the input limits (about 5*10**4 nodes with
# values up to 10**9) -- NOTE(review): confirm against the problem constraints.
_MAX_COST = 5 * (10 ** 13) + 1
def aux_traversal(root, size):
    """Try every tree edge as the first cut and return the best candidate.

    The walk is an iterative DFS that keeps ``on_path[i]`` True exactly for
    the nodes on the path from the root to the node currently expanded.
    best_split needs this path, because those ancestors' ``cost`` still
    includes the subtree that the first cut removes.

    Args:
        root: root Node whose ``cost`` fields already hold subtree sums.
        size: length of the path table; node ids must lie in ``range(size)``.

    Returns:
        The smallest candidate value found, 0 as soon as a perfect split is
        seen, or -1 when no edge yields a valid split.
    """
    on_path = [False] * size
    on_path[0] = True
    answer = _MAX_COST

    # Stack entries are (node, leaving). leaving=False: expand the node;
    # leaving=True: its whole subtree is finished, so drop it from the path.
    pending = deque([(root, False)])
    while pending:
        node, leaving = pending.pop()
        if leaving:
            on_path[node.nid] = False
            continue

        on_path[node.nid] = True
        pending.append((node, True))

        for child in node.children:
            # Cutting the edge node -> child leaves two pieces.
            upper = root.cost - child.cost  # everything outside child's subtree
            lower = child.cost              # child's subtree

            # Pieces already equal: the candidate is that common cost
            # (presumably a new node of that value forms the third piece).
            if upper == lower:
                equal_case = upper
            else:
                equal_case = _MAX_COST
            # Second cut inside the upper piece, aiming for lower's cost.
            upper_case = best_split(root, lower, child.nid, on_path)
            # Second cut inside child's subtree, aiming for upper's cost.
            lower_case = best_split(child, upper)

            answer = min(answer, equal_case, upper_case, lower_case)
            if answer == 0:
                return 0
            pending.append((child, False))

    return answer if answer < _MAX_COST else -1
def best_split(root, target, exclude=None, path=None):
    """Look for one more cut that leaves a piece costing exactly ``target``.

    The tree under ``root`` (without the subtree rooted at ``exclude``, if
    given) is cut at one more edge, giving two pieces. Such a cut is useful
    when one piece costs exactly ``target`` and the other costs less; its
    price is ``target`` minus the smaller piece, or 0 if both pieces match.

    Args:
        root: Node whose subtree is searched.
        target: cost the pieces must reach.
        exclude: nid of a subtree that was already cut away. Its cost is
            taken to be ``target``, so ``target`` is deducted from the total.
        path: per-nid flags marking the ancestors of ``exclude``. Their
            ``cost`` still includes the excluded subtree, so ``target`` is
            deducted for them as well.

    Returns:
        The cheapest price found, or ``_MAX_COST`` if no cut qualifies.
    """
    whole_cost = root.cost
    if exclude is not None:
        whole_cost -= target

    # The two pieces add up to whole_cost, and one of them must equal target
    # while the other is positive and at most target. So whole_cost must
    # satisfy target < whole_cost <= 2 * target; anything else (too small,
    # or so large it would need a negative insertion) is rejected.
    if not (target < whole_cost <= target * 2):
        return _MAX_COST

    cheapest = _MAX_COST
    pending = deque([root])
    while pending:
        node = pending.pop()
        for child in node.children:
            if child.nid == exclude:
                continue  # never descend into the already-removed subtree
            part = child.cost
            if path and path[child.nid]:
                part -= target
            rest = whole_cost - part

            if part == target and rest == target:
                return 0  # perfect split, nothing can beat it
            if rest == target and part < target:
                cheapest = min(cheapest, target - part)
            elif part == target and rest < target:
                cheapest = min(cheapest, target - rest)
            pending.append(child)
    return cheapest
def balancedForest(tree_values, tree_edges):
    """Solve one query: build the tree and return aux_traversal's answer.

    The result is the smallest candidate value found over all two-cut
    splits (0 for a perfect split), or -1 when no split qualifies.

    Args:
        tree_values: node values; position i is input node i+1. Any
            iterable is accepted. It is turned into a list first because
            constructTree needs ``len()`` and indexing, and under Python 3
            a ``map(...)`` result is a one-shot iterator, not a list.
        tree_edges: iterable of 1-based ``(u, v)`` edge pairs.
    """
    values = list(tree_values)
    root = constructTree(values, tree_edges)
    return aux_traversal(root, len(values))
if __name__ == '__main__':
    # Read stdin directly instead of using raw_input()/xrange/map-as-list, so
    # the script runs unchanged under Python 2 and Python 3. Python 3 has no
    # raw_input or xrange, and its map() returns a lazy iterator that
    # constructTree cannot len() or index. The ``with`` block closes the
    # output file even if a query raises.
    read_line = sys.stdin.readline
    with open(os.environ['OUTPUT_PATH'], 'w') as fptr:
        q = int(read_line())
        for q_itr in range(q):
            n = int(read_line())
            c = [int(tok) for tok in read_line().split()]
            # n nodes form a tree, so exactly n - 1 edge lines follow.
            edges = [[int(tok) for tok in read_line().split()]
                     for _ in range(n - 1)]
            result = balancedForest(c, edges)
            fptr.write(str(result) + '\n')