#pragma once
#include <iomanip>
#include <iostream>
using namespace std;
// A self-balancing (AVL) binary search tree of ints.
// Every node caches the height of its subtree so that insert() can detect and
// repair imbalance on the way back up the recursion. Values equal to an
// existing node's value are placed in that node's right subtree.
class AVL
{
public:
    // Creates an empty tree.
    AVL() : root(nullptr) {}

    // Deep copy: the new tree owns its own nodes, so destroying either tree
    // never frees nodes that the other still uses.
    AVL(const AVL& other) : root(clone(other.root)) {}

    // Move: takes over other's nodes and leaves other empty.
    AVL(AVL&& other) noexcept : root(other.root) {
        other.root = nullptr;
    }

    // Replaces this tree's contents with a deep copy of other.
    AVL& operator=(const AVL& other) {
        if (this != &other) {
            // Clone before destroying so a failed allocation leaves *this unchanged.
            Node* copy = clone(other.root);
            destroy(root);
            root = copy;
        }
        return *this;
    }

    // Frees this tree's nodes and takes over other's; other is left empty.
    AVL& operator=(AVL&& other) noexcept {
        if (this != &other) {
            destroy(root);
            root = other.root;
            other.root = nullptr;
        }
        return *this;
    }

    // Frees every node in the tree.
    ~AVL() {
        destroy(root);
    }

    // A tree node. `height` caches the height of the subtree rooted here:
    // it is 0 until updateHeight() runs, after which a leaf has height 1
    // (an absent child counts as height 0).
    struct Node {
        int data;
        Node *left;
        Node *right;
        int height;

        Node(int d) : data(d), left(nullptr), right(nullptr), height(0) {}

        // Recomputes `height` from the children's cached heights.
        // The children's heights must already be up to date.
        void updateHeight() {
            int lHeight = (left != nullptr) ? left->height : 0;
            int rHeight = (right != nullptr) ? right->height : 0;
            height = ((lHeight > rHeight) ? lHeight : rHeight) + 1;
        }
    };

    // Inserts val, rebalancing every node on the insertion path.
    void insert(int val) {
        insert(val, root);
    }

    // Left rotation. leaf->right must be non-null.
    // Returns the new subtree root (leaf's former right child); the caller must
    // store it where leaf used to hang. Both moved nodes get fresh heights.
    Node* rotateLeft(Node *&leaf) {
        Node* pivot = leaf->right;
        leaf->right = pivot->left;
        pivot->left = leaf;
        // leaf is now pivot's child, so it must be refreshed before pivot.
        leaf->updateHeight();
        pivot->updateHeight();
        return pivot;
    }

    // Right rotation. leaf->left must be non-null.
    // Returns the new subtree root (leaf's former left child); the caller must
    // store it where leaf used to hang. Both moved nodes get fresh heights.
    Node* rotateRight(Node *&leaf) {
        Node* pivot = leaf->left;
        leaf->left = pivot->right;
        pivot->right = leaf;
        // leaf is now pivot's child, so it must be refreshed before pivot.
        leaf->updateHeight();
        pivot->updateHeight();
        return pivot;
    }

    // Right-left double rotation: rotates leaf's right child to the right,
    // then rotates leaf to the left. Returns the new subtree root.
    Node* rotateRightLeft(Node *&leaf) {
        Node* child = leaf->right;
        leaf->right = rotateRight(child);
        return rotateLeft(leaf);
    }

    // Left-right double rotation: rotates leaf's left child to the left,
    // then rotates leaf to the right. Returns the new subtree root.
    Node* rotateLeftRight(Node *&leaf) {
        Node* child = leaf->left;
        leaf->left = rotateLeft(child);
        return rotateRight(leaf);
    }

    // Restores the AVL property at this single node. If its children's heights
    // differ by more than 1, the appropriate rotation is applied and `leaf` is
    // replaced by the new subtree root. Children's cached heights must be current.
    // A null leaf is left untouched.
    void rebalance(Node *&leaf) {
        if (leaf == nullptr) {
            return;
        }
        int hDiff = getDiff(leaf);
        if (hDiff > 1) {
            // Left-heavy. A balanced (0) or left-leaning left child needs only a
            // single rotation; a right-leaning one needs the double rotation.
            if (getDiff(leaf->left) >= 0) {
                leaf = rotateRight(leaf);
            } else {
                leaf = rotateLeftRight(leaf);
            }
        } else if (hDiff < -1) {
            // Right-heavy: mirror image of the case above.
            if (getDiff(leaf->right) <= 0) {
                leaf = rotateLeft(leaf);
            } else {
                leaf = rotateRightLeft(leaf);
            }
        }
    }

private:
    Node *root;

    // BST insert, then refresh the height and rebalance each node while
    // unwinding the recursion. Equal values go to the right subtree.
    void insert(int d, Node *&leaf) {
        if (leaf == nullptr) {
            leaf = new Node(d);
            leaf->updateHeight();
            return;
        }
        if (d < leaf->data) {
            insert(d, leaf->left);
        } else {
            insert(d, leaf->right);
        }
        leaf->updateHeight();
        rebalance(leaf);
    }

    // Post-order delete of the subtree rooted at leaf; leaves leaf == nullptr.
    static void destroy(Node *&leaf) {
        if (leaf != nullptr) {
            destroy(leaf->left);
            destroy(leaf->right);
            delete leaf;
            leaf = nullptr;
        }
    }

    // Returns a deep copy of the subtree rooted at leaf (nullptr for nullptr),
    // preserving each node's cached height. On allocation failure, any partial
    // copy is freed and the exception is rethrown.
    static Node* clone(const Node* leaf) {
        if (leaf == nullptr) {
            return nullptr;
        }
        Node* copy = new Node(leaf->data);
        copy->height = leaf->height;
        try {
            copy->left = clone(leaf->left);
            copy->right = clone(leaf->right);
        } catch (...) {
            destroy(copy);
            throw;
        }
        return copy;
    }

    // Returns (left subtree height) - (right subtree height); a null node or
    // a missing child counts as height 0. Positive means the left side is
    // taller, negative means the right side is taller.
    int getDiff(Node *leaf) {
        if (leaf == nullptr) {
            return 0;
        }
        int lHeight = (leaf->left != nullptr) ? leaf->left->height : 0;
        int rHeight = (leaf->right != nullptr) ? leaf->right->height : 0;
        return lHeight - rHeight;
    }
};