diff --git a/Python/optimal_binary_search_tree.py b/Python/optimal_binary_search_tree.py new file mode 100644 index 0000000..550a8e9 --- /dev/null +++ b/Python/optimal_binary_search_tree.py @@ -0,0 +1,142 @@ +# This Python program implements an optimal binary search tree (abbreviated BST) +# building dynamic programming algorithm that delivers O(n^2) performance. +# +# The goal of the optimal BST problem is to build a low-cost BST for a +# given set of nodes, each with its own key and frequency. The frequency +# of the node is defined as how many time the node is being searched. +# The search cost of binary search tree is given by this formula: +# +# cost(1, n) = sum{i = 1 to n}((depth(node_i) + 1) * node_i_freq) +# +# where n is number of nodes in the BST. The characteristic of low-cost +# BSTs is having a faster overall search time than other implementations. +# The reason for their fast search time is that the nodes with high +# frequencies will be placed near the root of the tree while the nodes +# with low frequencies will be placed near the leaves of the tree thus +# reducing search time in the most frequent instances. +import sys +from random import randint + + +class Node: + """Binary Search Tree Node""" + + def __init__(self, key, freq): + self.key = key + self.freq = freq + + def __str__(self): + """ + >>> str(Node(1, 2)) + 'Node(key=1, freq=2)' + """ + return f"Node(key={self.key}, freq={self.freq})" + + +def print_binary_search_tree(root, key, i, j, parent, is_left): + """ + Recursive function to print a BST from a root table. + + >>> key = [3, 8, 9, 10, 17, 21] + >>> root = [[0, 1, 1, 1, 1, 1], [0, 1, 1, 1, 1, 3], [0, 0, 2, 3, 3, 3], \ + [0, 0, 0, 3, 3, 3], [0, 0, 0, 0, 4, 5], [0, 0, 0, 0, 0, 5]] + >>> print_binary_search_tree(root, key, 0, 5, -1, False) + 8 is the root of the binary search tree. + 3 is the left child of key 8. + 10 is the right child of key 8. + 9 is the left child of key 10. + 21 is the right child of key 10. + 17 is the left child of key 21. + """ + if i > j or i < 0 or j > len(root) - 1: + return + + node = root[i][j] + if parent == -1: # root does not have a parent + print(f"{key[node]} is the root of the binary search tree.") + elif is_left: + print(f"{key[node]} is the left child of key {parent}.") + else: + print(f"{key[node]} is the right child of key {parent}.") + + print_binary_search_tree(root, key, i, node - 1, key[node], True) + print_binary_search_tree(root, key, node + 1, j, key[node], False) + + +def find_optimal_binary_search_tree(nodes): + """ + This function calculates and prints the optimal binary search tree. + The dynamic programming algorithm below runs in O(n^2) time. + Implemented from CLRS (Introduction to Algorithms) book. + https://en.wikipedia.org/wiki/Introduction_to_Algorithms + + >>> find_optimal_binary_search_tree([Node(12, 8), Node(10, 34), Node(20, 50), \ + Node(42, 3), Node(25, 40), Node(37, 30)]) + Binary search tree nodes: + Node(key=10, freq=34) + Node(key=12, freq=8) + Node(key=20, freq=50) + Node(key=25, freq=40) + Node(key=37, freq=30) + Node(key=42, freq=3) + + The cost of optimal BST for given tree nodes is 324. + 20 is the root of the binary search tree. + 10 is the left child of key 20. + 12 is the right child of key 10. + 25 is the right child of key 20. + 37 is the right child of key 25. + 42 is the right child of key 37. + """ + # Tree nodes must be sorted first, the code below sorts the keys in + # increasing order and rearrange its frequencies accordingly. + nodes.sort(key=lambda node: node.key) + + n = len(nodes) + + keys = [nodes[i].key for i in range(n)] + freqs = [nodes[i].freq for i in range(n)] + + # This 2D array stores the overall tree cost (which's as minimized as possible); + # for a single key, cost is equal to frequency of the key. + dp = [[freqs[i] if i == j else 0 for j in range(n)] for i in range(n)] + # sum[i][j] stores the sum of key frequencies between i and j inclusive in nodes + # array + total = [[freqs[i] if i == j else 0 for j in range(n)] for i in range(n)] + # stores tree roots that will be used later for constructing binary search tree + root = [[i if i == j else 0 for j in range(n)] for i in range(n)] + + for interval_length in range(2, n + 1): + for i in range(n - interval_length + 1): + j = i + interval_length - 1 + + dp[i][j] = sys.maxsize # set the value to "infinity" + total[i][j] = total[i][j - 1] + freqs[j] + + # Apply Knuth's optimization + # Loop without optimization: for r in range(i, j + 1): + for r in range(root[i][j - 1], root[i + 1][j] + 1): # r is a temporal root + left = dp[i][r - 1] if r != i else 0 # optimal cost for left subtree + right = dp[r + 1][j] if r != j else 0 # optimal cost for right subtree + cost = left + total[i][j] + right + + if dp[i][j] > cost: + dp[i][j] = cost + root[i][j] = r + + print("Binary search tree nodes:") + for node in nodes: + print(node) + + print(f"\nThe cost of optimal BST for given tree nodes is {dp[0][n - 1]}.") + print_binary_search_tree(root, keys, 0, n - 1, -1, False) + + +def main(): + # A sample binary search tree + nodes = [Node(i, randint(1, 50)) for i in range(10, 0, -1)] + find_optimal_binary_search_tree(nodes) + + +if __name__ == "__main__": + main()