Here is a really basic binary search tree class. It includes only the basics: creating the tree, inserting, erasing, and returning its size. In later posts I will talk about printing and traversals.
This class also uses the Node.h header described in an earlier post. You'll notice that I really like to use recursion; I think it is cleaner than looping.
#include <assert.h>
#include "Node.h"
using namespace std;
class Bst
{
public:
//constructor for when a head Node is provided and when it is not
Bst() {
root = nullptr;
}
Bst(Node *np) {
root = np;
}
//destroy the tree, we need to go through and destroy each node
~Bst() {
destroyTree(root);
}
//get the number of nodes in the tree
int size() {
return size(root);
}
//erase a value in the tree
void erase(int item) {
erase(item, root);
}
//insert a Node in the tree
void insert(int item) {
insert(item, root);
}
private:
Node* root;
//Go through each branch and recursively destroy all Nodes
void destroyTree(Node*& n) {
if (n != nullptr) {
destroyTree(n->left);
destroyTree(n->right);
delete n;
}
}
//For each Node return the number of left and right nodes
//Add it up recursively to get the total size
int size(Node* n) {
if (n != nullptr) {
int left = size(n->left);
int right = size(n->right);
int self = 1;
return left + self + right;
}
return 0;
}
//Find the minimum Node value
Node* findMin(Node* n){
assert(n != nullptr);
if (n->left != nullptr) {
return findMin(n->left);
}
return n;
}
//this one is a beast
//look through all the nodes recursively
//once you find the node value there are numerous cases we need to look for
//If the current node does not have left and right nodes, just delete it
//If it does have a left or right node, set the child to the parent
//If it has both left and right, we need to work some magic. First we find
//the smallest value and set the node we want to delete to that value (removing it)
void erase(int item, Node*& n) {
if (n != nullptr) {
if (item == n->data) {
if (n->right == nullptr && n->left == nullptr) {
delete n;
n = nullptr;
} else if (n->right == nullptr) {
Node* temp = n;
n = n->left;
delete n;
} else if (n->left == nullptr){
Node* temp = n;
n = n->right;
delete n;
} else {
Node *temp = findMin(n->right);
n->data = temp->data;
erase(item, n->right);
}
} else if (item < n->data) {
erase(item, n->left);
} else {
erase(item, n->right);
}
}
}
//look through all the nodes
//insert the node on the correct node, it will be added to the left if the value is less
//added to the right if the value is greater
void insert(int item, Node*& n) {
if (n != nullptr) {
if (item < n->data) {
insert(item, n->left);
} else {
insert(item, n->right);
}
} else {
n = new Node(item);
}
}
};
Let me know if you have any improvements or comments!