Improved del_node func
There was a logical error in the implementation of the delete-node function.
Also, the balance factor does not need to be computed twice, so I stored it in a separate variable.
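A minimal sketch of the idea, not the PR's code: the balance factor is computed once, and the single-vs-double rotation case is chosen from the child's balance factor. Unlike comparing keys (as insert_node does), this also works after a deletion. `Node`, `rebalance`, and the `rotate_*` helpers are hypothetical names for this example only.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Node:
    # Hypothetical minimal node for this sketch (not the PR's MyNode).
    data: int
    left: Node | None = None
    right: Node | None = None
    height: int = 1


def height(node: Node | None) -> int:
    return node.height if node else 0


def update_height(node: Node) -> None:
    node.height = 1 + max(height(node.left), height(node.right))


def balance_of(node: Node | None) -> int:
    return height(node.left) - height(node.right) if node else 0


def rotate_right(node: Node) -> Node:
    # Lift the left child into node's place.
    pivot = node.left
    assert pivot is not None
    node.left = pivot.right
    pivot.right = node
    update_height(node)
    update_height(pivot)
    return pivot


def rotate_left(node: Node) -> Node:
    # Lift the right child into node's place.
    pivot = node.right
    assert pivot is not None
    node.right = pivot.left
    pivot.left = node
    update_height(node)
    update_height(pivot)
    return pivot


def rebalance(node: Node) -> Node:
    update_height(node)
    balance = balance_of(node)  # computed once, reused below
    if balance > 1:
        left = node.left
        assert left is not None  # balance > 1 means the left side is non-empty
        if balance_of(left) < 0:  # LR case: left child leans right
            node.left = rotate_left(left)
        # A child balance of 0 (possible after deletion) needs only one rotation.
        return rotate_right(node)
    if balance < -1:
        right = node.right
        assert right is not None
        if balance_of(right) > 0:  # RL case: right child leans left
            node.right = rotate_right(right)
        return rotate_left(node)
    return node
```

Calling `rebalance` on every node on the way back up from a recursive insert or delete keeps heights and balance in sync.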
Describe your change:
- [x] Add an algorithm?
- [x] Fix a bug or typo in an existing algorithm?
- [ ] Add or change doctests? -- Note: Please avoid changing both code and tests in a single pull request.
- [ ] Documentation change?
Checklist:
- [x] I have read CONTRIBUTING.md.
- [x] This pull request is all my own work -- I have not plagiarized.
- [x] I know that pull requests will not be merged if they fail the automated tests.
- [x] This PR only changes one algorithm file. To ease review, please open separate PRs for separate algorithms.
- [x] All new Python files are placed inside an existing directory.
- [x] All filenames are in all lowercase characters with no spaces or dashes.
- [x] All functions and variable names follow Python naming conventions.
- [x] All function parameters and return values are annotated with Python type hints.
- [x] All functions have doctests that pass the automated testing.
- [x] All new algorithms include at least one URL that points to Wikipedia or another similar explanation.
- [x] If this pull request resolves one or more open issues then the description above includes the issue number(s) with a closing keyword: "Fixes #ISSUE-NUMBER".
If you run python avl_tree.py directly, it will crash. This suggests that the tests are insufficient.
insert:3
3
*************************************
insert:7
3
* 7
*************************************
insert:1
3
1 7
*************************************
insert:2
3
1 7
* 2 * *
*************************************
insert:5
3
1 7
* 2 5 *
*************************************
insert:4
left rotation node: 7
3
1 5
* 2 4 7
*************************************
insert:6
3
1 5
* 2 4 7
* * * * * * 6 *
*************************************
insert:0
3
1 5
0 2 4 7
* * * * * * 6 *
*************************************
insert:9
3
1 5
0 2 4 7
* * * * * * 6 9
*************************************
insert:8
right rotation node: 5
3
1 7
0 2 5 9
* * * * 4 6 8 *
*************************************
delete:2
3
1 7
0 * 5 9
* * * * 4 6 8 *
*************************************
delete:5
3
1 7
0 * 6 9
* * * * 4 * 8 *
*************************************
delete:9
3
1 7
0 * 6 8
* * * * 4 * * *
*************************************
delete:8
Traceback (most recent call last):
  File "/tmp/avl_tree.py", line 354, in <module>
    t.del_node(i)
  File "/tmp/avl_tree.py", line 300, in del_node
    self.root = del_node(self.root, data)
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/avl_tree.py", line 227, in del_node
    root.set_right(del_node(right_child, data))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/avl_tree.py", line 240, in del_node
    root = rl_rotation(root)
           ^^^^^^^^^^^^^^^^^
  File "/tmp/avl_tree.py", line 145, in rl_rotation
    assert right_child is not None
           ^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
@99991 Please review my recent pull request.
The tree becomes unbalanced when inserting and deleting the same value multiple times.
Here are a few more tests you can use for debugging. Note that the tests are very comprehensive and probably not suited for doctests, because they would take too long to run.
def dump_tree_inorder(node, data):
    if node:
        dump_tree_inorder(node.left, data)
        data.append(node.data)
        dump_tree_inorder(node.right, data)


def check_balance(node):
    if node:
        if node.left and node.right:
            balance = node.left.height - node.right.height
            assert abs(balance) <= 1, f"Height difference of more than 1 between left node {node.left.data} with height {node.left.height} and right node {node.right.data} with height {node.right.height} of parent node with data {node.data}"
        left_height = check_balance(node.left)
        right_height = check_balance(node.right)
        height = max(left_height, right_height) + 1
        assert height == node.height, f"Node {node.data} should have height {height} instead of height {node.height}"
        return height
    return 0
replay = """
insert:6
insert:1
insert:1
right rotation node: 1
left rotation node: 6
insert:8
insert:7
left rotation node: 8
right rotation node: 6
insert:9
right rotation node: 1
insert:9
right rotation node: 8
insert:6
insert:2
insert:2
left rotation node: 6
right rotation node: 1
insert:1
insert:0
left rotation node: 7
insert:5
insert:8
insert:8
right rotation node: 8
insert:4
left rotation node: 5
right rotation node: 2
delete:1
left rotation node: 1
left rotation node: 7
right rotation node: 2
"""
t = AVLtree()
# Parse printed output to reproduce error
for line in replay.strip().split("\n"):
    op, value = line.split(":")
    if op == "insert":
        t.insert(value)
    elif op == "delete":
        t.del_node(value)
    print(t)
    check_balance(t.root)
import random

# Brute-force test many insertion and deletion sequences
for length in range(1000):
    for seed in range(1000):
        print("seed:", seed, "length", length)
        random.seed(seed)
        t = AVLtree()
        s = []
        for _ in range(length):
            if random.random() < 0.4 and s:
                value = random.choice(s)
                i = s.index(value)
                s.pop(i)
                t.del_node(value)
            else:
                value = random.randrange(20)
                if value in s:
                    continue
                s.append(value)
                t.insert(value)
            check_balance(t.root)
            tree_data = []
            dump_tree_inorder(t.root, tree_data)
            assert sorted(s) == tree_data, f"Expected {sorted(s)}, got {tree_data}"
        print()
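One gap in check_balance above: it only compares heights when both children exist, so a node whose single child has height 2 (a three-node chain) passes the balance assertion. A minimal sketch of a stricter variant treats a missing child as height 0. It assumes a hypothetical `Node` class with the same `data`/`left`/`right`/`height` attributes the helper reads.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Node:
    # Hypothetical stand-in for MyNode with the attributes the checker reads.
    data: int
    left: Node | None = None
    right: Node | None = None
    height: int = 1


def check_avl(node: Node | None) -> int:
    """Return the actual height of `node`, asserting the AVL invariants."""
    if node is None:
        return 0
    left_height = check_avl(node.left)
    right_height = check_avl(node.right)
    # A missing child counts as height 0, so one-sided chains are caught too.
    assert abs(left_height - right_height) <= 1, (
        f"Node {node.data}: left height {left_height}, right height {right_height}"
    )
    height = max(left_height, right_height) + 1
    assert height == node.height, (
        f"Node {node.data}: stored height {node.height}, actual height {height}"
    )
    return height
```

With the chain 3 → 2 → 1, check_avl fails at node 3 (left height 2, right height 0), whereas check_balance returns 3 without complaint.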
@99991 An AVL tree must not contain duplicate nodes by definition (though we could maintain a count variable).
So should I either add a new variable that stores the count of duplicates (which also means changing other functions), or discard duplicate nodes on insertion (which only changes the insert function)?
Both solutions are good (much better than silently corrupting the tree anyway). I think ignoring duplicate keys is easier since it only requires an additional if-branch in insert_node.
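For comparison, the count-variable alternative could look roughly like this. It is a sketch on a plain binary search tree: AVL height updates and rotations are omitted. It uses a hypothetical `CountedNode` class. Inserting an existing key bumps its count, and `remove_one` only unlinks a node once its count drops to zero.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class CountedNode:
    # Hypothetical node: `count` records how many times the key was inserted,
    # so the tree itself never holds two nodes with the same key.
    data: int
    count: int = 1
    left: CountedNode | None = None
    right: CountedNode | None = None


def insert(node: CountedNode | None, data: int) -> CountedNode:
    if node is None:
        return CountedNode(data)
    if data == node.data:
        node.count += 1  # duplicate: bump the counter instead of adding a node
    elif data < node.data:
        node.left = insert(node.left, data)
    else:
        node.right = insert(node.right, data)
    return node


def remove_one(node: CountedNode | None, data: int) -> CountedNode | None:
    """Remove one occurrence of `data`; unlink the node only at count 0."""
    if node is None:
        return None
    if data < node.data:
        node.left = remove_one(node.left, data)
    elif data > node.data:
        node.right = remove_one(node.right, data)
    elif node.count > 1:
        node.count -= 1
    elif node.left is None:
        return node.right
    elif node.right is None:
        return node.left
    else:
        # Two children: copy the in-order successor (key and count) up,
        # then unlink the successor node entirely.
        successor = node.right
        while successor.left is not None:
            successor = successor.left
        node.data, node.count = successor.data, successor.count
        successor.count = 1  # ensure the recursive call unlinks it
        node.right = remove_one(node.right, successor.data)
    return node
```

The ignore-duplicates option reduces to a single `data == node.data` branch that returns the node unchanged, which is why it is the smaller change.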
@99991 Kindly review my pull request (for the third time).
The following insertion/deletion sequence will result in an unbalanced tree. The height of the left subtree of node 12 is 1, while the height of the right subtree is 3, which results in a larger-than-allowed height difference of 2.
insert:1
insert:5
insert:6
insert:12
insert:3
insert:9
insert:19
insert:0
insert:16
insert:4
insert:10
delete:16
insert:14
insert:17
insert:15
insert:7
delete:0
You can use this code to check:
def dump_tree_inorder(node, data):
    if node:
        dump_tree_inorder(node.left, data)
        data.append(node.data)
        dump_tree_inorder(node.right, data)


def check_balance(node):
    if node:
        if node.left and node.right:
            balance = node.left.height - node.right.height
            assert abs(balance) <= 1, f"Height difference of more than 1 between left node {node.left.data} with height {node.left.height} and right node {node.right.data} with height {node.right.height} of parent node with data {node.data}"
        left_height = check_balance(node.left)
        right_height = check_balance(node.right)
        height = max(left_height, right_height) + 1
        assert height == node.height, f"Node {node.data} should have height {height} instead of height {node.height}"
        return height
    return 0
replay = """
insert:1
insert:5
insert:6
insert:12
insert:3
insert:9
insert:19
insert:0
insert:16
insert:4
insert:10
delete:16
insert:14
insert:17
insert:15
insert:7
delete:0
"""
t = AVLtree()
# Parse printed output to reproduce error
for line in replay.strip().split("\n"):
    op, value = line.split(":")
    try:
        value = int(value.strip())
    except ValueError:
        continue
    if op == "insert":
        t.insert(value)
    elif op == "delete":
        t.del_node(value)
    print(t)
    check_balance(t.root)
import random

# Brute-force test many insertion and deletion sequences
for length in range(1000):
    for seed in range(1000):
        print("seed:", seed, "length", length)
        random.seed(seed * 1000 + length)
        t = AVLtree()
        s = set()
        for _ in range(length):
            if random.random() < 0.4 and s:
                value = random.choice(list(s))
                s.remove(value)
                t.del_node(value)
            else:
                value = random.randrange(20)
                if value in s:
                    continue
                s.add(value)
                t.insert(value)
            check_balance(t.root)
            tree_data = []
            dump_tree_inorder(t.root, tree_data)
            assert sorted(s) == tree_data, f"Expected {sorted(s)}, got {tree_data}"
        print()
Here is the output I get after running that replay:
insert:2
2
*************************************
insert:7
2
* 7
*************************************
insert:6
left rotation node: 7
right rotation node: 2
6
2 7
*************************************
insert:0
6
2 7
0 * * *
*************************************
insert:4
6
2 7
0 4 * *
*************************************
insert:9
6
2 7
0 4 * 9
*************************************
insert:8
left rotation node: 9
right rotation node: 7
6
2 8
0 4 7 9
*************************************
insert:1
6
2 8
0 4 7 9
* 1 * * * * * *
*************************************
insert:5
6
2 8
0 4 7 9
* 1 * 5 * * * *
*************************************
insert:3
6
2 8
0 4 7 9
* 1 3 5 * * * *
*************************************
delete:8
6
2 9
0 4 7 *
* 1 3 5 * * * *
*************************************
delete:2
6
3 9
0 4 7 *
* 1 * 5 * * * *
*************************************
delete:3
6
4 9
0 5 7 *
* 1 * * * * * *
*************************************
delete:4
right rotation node: 0
left rotation node: 5
6
1 9
0 5 7 *
*************************************
delete:9
6
1 7
0 5 * *
*************************************
delete:0
6
1 7
* 5 * *
*************************************
delete:7
right rotation node: 1
left rotation node: 6
5
1 6
*************************************
delete:5
6
1 *
*************************************
delete:6
1
*************************************
delete:1
insert:1
1
*************************************
insert:5
1
* 5
*************************************
insert:6
right rotation node: 1
5
1 6
*************************************
insert:12
5
1 6
* * * 12
*************************************
insert:3
5
1 6
* 3 * 12
*************************************
insert:9
left rotation node: 12
right rotation node: 6
5
1 9
* 3 6 12
*************************************
insert:19
5
1 9
* 3 6 12
* * * * * * * 19
*************************************
insert:0
5
1 9
0 3 6 12
* * * * * * * 19
*************************************
insert:16
left rotation node: 19
right rotation node: 12
5
1 9
0 3 6 16
* * * * * * 12 19
*************************************
insert:4
5
1 9
0 3 6 16
* * * 4 * * 12 19
*************************************
insert:10
left rotation node: 16
right rotation node: 9
5
1 12
0 3 9 16
* * * 4 6 10 * 19
*************************************
delete:16
5
1 12
0 3 9 19
* * * 4 6 10 * *
*************************************
insert:14
5
1 12
0 3 9 19
* * * 4 6 10 14 *
*************************************
insert:17
right rotation node: 14
left rotation node: 19
5
1 12
0 3 9 17
* * * 4 6 10 14 19
*************************************
insert:15
5
1 12
0 3 9 17
* * * 4 6 10 14 19
* * * * * * * * * * * * * 15 * *
*************************************
insert:7
5
1 12
0 3 9 17
* * * 4 6 10 14 19
* * * * * * * * * 7 * * * 15 * *
*************************************
delete:0
right rotation node: 1
right rotation node: 5
12
5 17
3 9 14 19
1 4 6 10 * 15 * *
* * * * * 7 * * * * * * * * * *
*************************************
I think it is correct now. Please review the recent commit.
Run the code below:
"""
Implementation of an auto-balanced binary tree!
For doctests run following command:
python3 -m doctest -v avl_tree.py
For testing run:
python avl_tree.py
"""
from __future__ import annotations
import math
import random
from typing import Any


class MyQueue:
    def __init__(self) -> None:
        self.data: list[Any] = []
        self.head: int = 0
        self.tail: int = 0

    def is_empty(self) -> bool:
        return self.head == self.tail

    def push(self, data: Any) -> None:
        self.data.append(data)
        self.tail = self.tail + 1

    def pop(self) -> Any:
        ret = self.data[self.head]
        self.head = self.head + 1
        return ret

    def count(self) -> int:
        return self.tail - self.head

    def print_queue(self) -> None:
        print(self.data)
        print("**************")
        print(self.data[self.head : self.tail])


class MyNode:
    def __init__(self, data: Any) -> None:
        self.data = data
        self.left: MyNode | None = None
        self.right: MyNode | None = None
        self.height: int = 1

    def get_data(self) -> Any:
        return self.data

    def get_left(self) -> MyNode | None:
        return self.left

    def get_right(self) -> MyNode | None:
        return self.right

    def get_height(self) -> int:
        return self.height

    def set_data(self, data: Any) -> None:
        self.data = data

    def set_left(self, node: MyNode | None) -> None:
        self.left = node

    def set_right(self, node: MyNode | None) -> None:
        self.right = node

    def set_height(self, height: int) -> None:
        self.height = height


def get_height(node: MyNode | None) -> int:
    if node is None:
        return 0
    return node.get_height()


def my_max(a: int, b: int) -> int:
    if a > b:
        return a
    return b


def right_rotation(node: MyNode) -> MyNode:
    r"""
                    A                       B
                   / \                     / \
                  B   C                   Bl  A
                 / \       -->           /   / \
                Bl  Br                  UB Br   C
               /
             UB
    UB = unbalanced node
    """
    print("left rotation node:", node.get_data())
    ret = node.get_left()
    assert ret is not None
    node.set_left(ret.get_right())
    ret.set_right(node)
    h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1
    node.set_height(h1)
    h2 = my_max(get_height(ret.get_right()), get_height(ret.get_left())) + 1
    ret.set_height(h2)
    return ret


def left_rotation(node: MyNode) -> MyNode:
    """
    A mirror-symmetric rotation of right_rotation.
    """
    print("right rotation node:", node.get_data())
    ret = node.get_right()
    assert ret is not None
    node.set_right(ret.get_left())
    ret.set_left(node)
    h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1
    node.set_height(h1)
    h2 = my_max(get_height(ret.get_right()), get_height(ret.get_left())) + 1
    ret.set_height(h2)
    return ret


def lr_rotation(node: MyNode) -> MyNode:
    r"""
                    A              A                    Br
                   / \            / \                  /  \
                  B   C    LR    Br  C       RR       B    A
                 / \       -->  /  \         -->    /     / \
                Bl  Br         B   UB              Bl    UB  C
                     \        /
                     UB     Bl
    RR = right_rotation   LR = left_rotation
    """
    left_child = node.get_left()
    assert left_child is not None
    node.set_left(left_rotation(left_child))
    return right_rotation(node)


def rl_rotation(node: MyNode) -> MyNode:
    right_child = node.get_right()
    assert right_child is not None
    node.set_right(right_rotation(right_child))
    return left_rotation(node)


def insert_node(node: MyNode | None, data: Any) -> MyNode | None:
    if node is None:
        return MyNode(data)
    if data == node.get_data():
        return node
    if data < node.get_data():
        node.set_left(insert_node(node.get_left(), data))
        if (
            get_height(node.get_left()) - get_height(node.get_right()) == 2
        ):  # an unbalance detected
            left_child = node.get_left()
            assert left_child is not None
            if (
                data < left_child.get_data()
            ):  # new node is the left child of the left child
                node = right_rotation(node)
            else:
                node = lr_rotation(node)
    else:
        node.set_right(insert_node(node.get_right(), data))
        if get_height(node.get_right()) - get_height(node.get_left()) == 2:
            right_child = node.get_right()
            assert right_child is not None
            if data < right_child.get_data():
                node = rl_rotation(node)
            else:
                node = left_rotation(node)
    h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1
    node.set_height(h1)
    return node


def get_right_most(root: MyNode) -> Any:
    while True:
        right_child = root.get_right()
        if right_child is None:
            break
        root = right_child
    return root.get_data()


def get_left_most(root: MyNode) -> Any:
    while True:
        left_child = root.get_left()
        if left_child is None:
            break
        root = left_child
    return root.get_data()


def get_balance(node: MyNode | None) -> int:
    if node is None:
        return 0
    return get_height(node.get_left()) - get_height(node.get_right())


def get_min_value_node(node: MyNode) -> MyNode:
    # Returns the node with the minimum value in the tree, i.e. the leftmost node.
    # get_left_most is not used here because it returns the node's value, not the node.
    while True:
        left_child = node.get_left()
        if left_child is None:
            break
        node = left_child
    return node


def del_node(root: MyNode | None, data: Any) -> MyNode | None:
    if root is None:
        print(f"{data} not found in the tree")
        return None
    if root.get_data() > data:
        left_child = del_node(root.get_left(), data)
        root.set_left(left_child)
    elif root.get_data() < data:
        right_child = del_node(root.get_right(), data)
        root.set_right(right_child)
    else:
        if root.get_left() is None:
            return root.get_right()
        elif root.get_right() is None:
            return root.get_left()
        right_child = root.get_right()
        assert right_child is not None
        temp = get_min_value_node(right_child)
        root.set_data(temp.get_data())
        root.set_right(del_node(root.get_right(), temp.get_data()))
    root.set_height(
        1 + my_max(get_height(root.get_left()), get_height(root.get_right()))
    )
    balance = get_balance(root)
    if balance > 1:
        left_child = root.get_left()
        assert left_child is not None
        if get_balance(left_child) >= 0:
            return right_rotation(root)
        root.set_left(left_rotation(left_child))
        return right_rotation(root)
    if balance < -1:
        right_child = root.get_right()
        assert right_child is not None
        if get_balance(right_child) <= 0:
            return left_rotation(root)
        root.set_right(right_rotation(right_child))
        return left_rotation(root)
    return root


class AVLtree:
    """
    An AVL tree doctest
    Examples:
    >>> t = AVLtree()
    >>> t.insert(4)
    insert:4
    >>> print(str(t).replace(" \\n","\\n"))
     4
    *************************************
    >>> t.insert(2)
    insert:2
    >>> print(str(t).replace(" \\n","\\n").replace(" \\n","\\n"))
      4
     2  *
    *************************************
    >>> t.insert(3)
    insert:3
    right rotation node: 2
    left rotation node: 4
    >>> print(str(t).replace(" \\n","\\n").replace(" \\n","\\n"))
      3
     2  4
    *************************************
    >>> t.get_height()
    2
    >>> t.delete(3)
    delete:3
    >>> print(str(t).replace(" \\n","\\n").replace(" \\n","\\n"))
      4
     2  *
    *************************************
    """

    def __init__(self) -> None:
        self.root: MyNode | None = None

    def get_height(self) -> int:
        return get_height(self.root)

    def insert(self, data: Any) -> None:
        print("insert:" + str(data))
        self.root = insert_node(self.root, data)

    def delete(self, data: Any) -> None:
        print("delete:" + str(data))
        if self.root is None:
            print("Tree is empty!")
            return
        self.root = del_node(self.root, data)

    def __str__(
        self,
    ) -> str:  # a level traversal, gives a more intuitive look at the tree
        output = ""
        q = MyQueue()
        q.push(self.root)
        layer = self.get_height()
        if layer == 0:
            return output
        cnt = 0
        while not q.is_empty():
            node = q.pop()
            space = " " * int(math.pow(2, layer - 1))
            output += space
            if node is None:
                output += "*"
                q.push(None)
                q.push(None)
            else:
                output += str(node.get_data())
                q.push(node.get_left())
                q.push(node.get_right())
            output += space
            cnt = cnt + 1
            for i in range(100):
                if cnt == math.pow(2, i) - 1:
                    layer = layer - 1
                    if layer == 0:
                        output += "\n*************************************"
                        return output
                    output += "\n"
                    break
        output += "\n*************************************"
        return output


def _test() -> None:
    import doctest

    doctest.testmod()


if __name__ == "__main__":
    _test()
    t = AVLtree()
    lst = list(range(10))
    random.shuffle(lst)
    for i in lst:
        t.insert(i)
        print(str(t))
    random.shuffle(lst)
    for i in lst:
        t.delete(i)
        print(str(t))
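MyQueue never discards popped items (pop only advances head), so the level-order walk in __str__ keeps every visited slot alive in self.data. Below is a sketch of the same traversal built on `collections.deque`, whose popleft is O(1) and drops the reference. It assumes a hypothetical `Node` class with `data`/`left`/`right`, groups the output by level, and uses "*" for empty slots as __str__ does. The spacing logic is omitted.

```python
from __future__ import annotations

from collections import deque
from dataclasses import dataclass


@dataclass
class Node:
    # Hypothetical stand-in for MyNode (only data/left/right are used).
    data: int
    left: Node | None = None
    right: Node | None = None


def levels(root: Node | None) -> list[list[str]]:
    """Return each level of the tree, using "*" for missing positions."""
    result: list[list[str]] = []
    queue: deque[Node | None] = deque([root])
    # Stop once a whole level consists of empty slots.
    while queue and any(node is not None for node in queue):
        level = []
        for _ in range(len(queue)):
            node = queue.popleft()
            if node is None:
                level.append("*")
                queue.extend([None, None])  # keep child slots aligned
            else:
                level.append(str(node.data))
                queue.extend([node.left, node.right])
        result.append(level)
    return result


print(levels(Node(4, left=Node(2))))
```

For the tree from the doctest (4 with left child 2), this yields `[['4'], ['2', '*']]`, the same rows that `str(t)` prints before padding.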
LGTM :+1:
(There might be a small issue when inserting the value NaN, because every ordering or equality comparison with NaN is false, but let's ignore that.)
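The NaN remark can be made concrete with a small sketch. The `*_branch` helpers below are hypothetical: they only mirror the comparison order used by insert_node and del_node. `validate_key` is likewise a hypothetical guard, not part of the PR. On insertion, NaN always falls through to the right branch and is never treated as a duplicate. In del_node, both `>` and `<` are false, so, if I read the code correctly, a lookup for NaN lands in the "found" branch at the first node it visits (the root).

```python
import math

nan = float("nan")

# Ordering and equality comparisons involving NaN are all False.
comparisons = [nan < 1.0, nan > 1.0, nan == nan, nan == 1.0]


def insert_node_branch(node_key: float, data: float) -> str:
    # Same comparison order as insert_node: equality first, then "<".
    if data == node_key:
        return "duplicate"
    if data < node_key:
        return "left"
    return "right"


def del_node_branch(node_key: float, data: float) -> str:
    # Same comparison order as del_node: ">" first, then "<", else found.
    if node_key > data:
        return "left"
    if node_key < data:
        return "right"
    return "match"


def validate_key(data: object) -> None:
    # Hypothetical guard a caller could run before AVLtree.insert.
    if isinstance(data, float) and math.isnan(data):
        raise ValueError("NaN has no ordering, so it cannot be used as a key")


print(comparisons)
print(insert_node_branch(5.0, nan), del_node_branch(5.0, nan))
```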
The pull request has been reviewed. Could you please merge it?
I can not merge this PR because I do not have write access to this repository. You can go through recently merged PRs to check who approved/merged them and ask if they can also merge yours: https://github.com/TheAlgorithms/Python/pulls?q=is%3Apr+is%3Amerged
@cclauss The pull request has been reviewed. Could you please merge it?
@tianyizheng02 The pull request has been reviewed. Could you please merge it?