Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: solution for optimal BST in Python #333

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions Python/optimal_binary_search_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# This Python program implements an optimal binary search tree (abbreviated BST)
# building dynamic programming algorithm that delivers O(n^2) performance.
#
# The goal of the optimal BST problem is to build a low-cost BST for a
# given set of nodes, each with its own key and frequency. The frequency
# of the node is defined as how many time the node is being searched.
# The search cost of binary search tree is given by this formula:
#
# cost(1, n) = sum{i = 1 to n}((depth(node_i) + 1) * node_i_freq)
#
# where n is number of nodes in the BST. The characteristic of low-cost
# BSTs is having a faster overall search time than other implementations.
# The reason for their fast search time is that the nodes with high
# frequencies will be placed near the root of the tree while the nodes
# with low frequencies will be placed near the leaves of the tree thus
# reducing search time in the most frequent instances.
import sys
from random import randint


class Node:
"""Binary Search Tree Node"""

def __init__(self, key, freq):
self.key = key
self.freq = freq

def __str__(self):
"""
>>> str(Node(1, 2))
'Node(key=1, freq=2)'
"""
return f"Node(key={self.key}, freq={self.freq})"


def print_binary_search_tree(root, key, i, j, parent, is_left):
"""
Recursive function to print a BST from a root table.

>>> key = [3, 8, 9, 10, 17, 21]
>>> root = [[0, 1, 1, 1, 1, 1], [0, 1, 1, 1, 1, 3], [0, 0, 2, 3, 3, 3], \
[0, 0, 0, 3, 3, 3], [0, 0, 0, 0, 4, 5], [0, 0, 0, 0, 0, 5]]
>>> print_binary_search_tree(root, key, 0, 5, -1, False)
8 is the root of the binary search tree.
3 is the left child of key 8.
10 is the right child of key 8.
9 is the left child of key 10.
21 is the right child of key 10.
17 is the left child of key 21.
"""
if i > j or i < 0 or j > len(root) - 1:
return

node = root[i][j]
if parent == -1: # root does not have a parent
print(f"{key[node]} is the root of the binary search tree.")
elif is_left:
print(f"{key[node]} is the left child of key {parent}.")
else:
print(f"{key[node]} is the right child of key {parent}.")

print_binary_search_tree(root, key, i, node - 1, key[node], True)
print_binary_search_tree(root, key, node + 1, j, key[node], False)


def find_optimal_binary_search_tree(nodes):
"""
This function calculates and prints the optimal binary search tree.
The dynamic programming algorithm below runs in O(n^2) time.
Implemented from CLRS (Introduction to Algorithms) book.
https://en.wikipedia.org/wiki/Introduction_to_Algorithms

>>> find_optimal_binary_search_tree([Node(12, 8), Node(10, 34), Node(20, 50), \
Node(42, 3), Node(25, 40), Node(37, 30)])
Binary search tree nodes:
Node(key=10, freq=34)
Node(key=12, freq=8)
Node(key=20, freq=50)
Node(key=25, freq=40)
Node(key=37, freq=30)
Node(key=42, freq=3)
<BLANKLINE>
The cost of optimal BST for given tree nodes is 324.
20 is the root of the binary search tree.
10 is the left child of key 20.
12 is the right child of key 10.
25 is the right child of key 20.
37 is the right child of key 25.
42 is the right child of key 37.
"""
# Tree nodes must be sorted first, the code below sorts the keys in
# increasing order and rearrange its frequencies accordingly.
nodes.sort(key=lambda node: node.key)

n = len(nodes)

keys = [nodes[i].key for i in range(n)]
freqs = [nodes[i].freq for i in range(n)]

# This 2D array stores the overall tree cost (which's as minimized as possible);
# for a single key, cost is equal to frequency of the key.
dp = [[freqs[i] if i == j else 0 for j in range(n)] for i in range(n)]
# sum[i][j] stores the sum of key frequencies between i and j inclusive in nodes
# array
total = [[freqs[i] if i == j else 0 for j in range(n)] for i in range(n)]
# stores tree roots that will be used later for constructing binary search tree
root = [[i if i == j else 0 for j in range(n)] for i in range(n)]

for interval_length in range(2, n + 1):
for i in range(n - interval_length + 1):
j = i + interval_length - 1

dp[i][j] = sys.maxsize # set the value to "infinity"
total[i][j] = total[i][j - 1] + freqs[j]

# Apply Knuth's optimization
# Loop without optimization: for r in range(i, j + 1):
for r in range(root[i][j - 1], root[i + 1][j] + 1): # r is a temporal root
left = dp[i][r - 1] if r != i else 0 # optimal cost for left subtree
right = dp[r + 1][j] if r != j else 0 # optimal cost for right subtree
cost = left + total[i][j] + right

if dp[i][j] > cost:
dp[i][j] = cost
root[i][j] = r

print("Binary search tree nodes:")
for node in nodes:
print(node)

print(f"\nThe cost of optimal BST for given tree nodes is {dp[0][n - 1]}.")
print_binary_search_tree(root, keys, 0, n - 1, -1, False)


def main():
# A sample binary search tree
nodes = [Node(i, randint(1, 50)) for i in range(10, 0, -1)]
find_optimal_binary_search_tree(nodes)


if __name__ == "__main__":
main()