원리
이진 트리 탐색 (Binary Search Tree)는 최악의 경우 \(O(n)\)이라는 시간 복잡도를 가집니다. 이를 대비하여 트리를 꾸준하게 리밸런싱(Rebalancing)을 해줘야하고, 그로 인해 나온 트리 중 하나가 AVL Tree입니다.
- 균형 인수 : 왼쪽 서브 트리의 높이 - 오른쪽 서브 트리의 높이
리밸런싱을 진행할 시 균형인수의 절댓값이 2이상이여야 합니다.
LL 회전
- 왼쪽에 노드가 2개 이상, 즉 균형 인수가 2이상인 경우 LL회전을 통해 균형을 잡습니다.
RR 회전
- 오른쪽에 노드가 2개 이상, 즉 균형 인수가 -2이하인 경우 RR회전을 통해 균형을 잡습니다.
LR 회전
- 왼쪽에 자식 노드가 존재하고 그 자식 노드의 오른쪽에 자식 노드가 존재하는 서브트리, 즉 균형 인수가 왼쪽 서브트리를 기준하여 계산할 때 마이너스일 시 RR회전 후 LL회전을 통해 균형을 잡습니다.
RL 회전
- 오른쪽에 자식 노드가 존재하고 그 자식 노드의 왼쪽에 자식 노드가 존재하는 서브트리, 즉 균형 인수가 오른쪽 서브트리를 기준하여 계산할 때 플러스일 시 LL회전 후 RR회전을 통해 균형을 잡습니다.
insert (삽입)
- root노드의 key값과 비교하여, 서브트리로 이동합니다. 서브트리로 이동 시 각 노드마다 리밸런싱을 진행하여 알맞는 노드의 자리를 찾는 작업을 진행합니다.
erase (삭제)
자식 노드가 없을 때
- target의 부모 노드의 자식 노드를 null로 바꾼 뒤 target을 delete합니다.
자식 노드가 하나일 때
- target의 부모 노드와 자식 노드를 이어 주고 target을 delete합니다.
자식 노드가 2개일 때
- target의 오른쪽 서브 트리 중 가장 작은 값을 부모 노드와 연결 후 target을 delete합니다.
find (탐색)
- root노드부터 key됴값을 비교하면서 서브트리로 이동합니다. target값과 같을 시 target의 value의 주소를 반환합니다.
소스 코드 (구현)
Node.h
#pragma once
#include <utility>
using namespace std;
template <typename Key, typename Val>
class Node
{
template <typename T, typename U>
friend class AVL;
private:
pair<Key, Val> node;
Node<Key, Val> *left = nullptr;
Node<Key, Val> *right = nullptr;
public:
virtual ~Node() = default;
Key GetKey() const { return node.first; }
Val GetData() const { return node.second; }
private:
void SetLeft(Node<Key, Val> *const sub)
{
if (left != nullptr)
delete this->left;
left = sub;
}
void SetRight(Node<Key, Val> *const sub)
{
if (right != nullptr)
delete this->right;
right = sub;
}
void ChangeLeft(Node<Key, Val> *const sub) { left = sub; }
void ChangeRight(Node<Key, Val> *const sub) { right = sub; }
Node *RemoveLeft()
{
Node<Key, Val> *delNode = nullptr;
if(this != nullptr)
{
delNode = left;
left = nullptr;
}
return delNode;
}
Node *RemoveRight()
{
Node<Key, Val> *delNode = nullptr;
if(this != nullptr)
{
delNode = right;
right = nullptr;
}
return delNode;
}
const int GetHeight() const
{
if(this == nullptr)
return 0;
const int left_height = left->GetHeight();
const int right_height = right->GetHeight();
if(left_height > right_height)
return left_height + 1;
else
return right_height + 1;
}
const int GetHighDiff() const
{
if(this == nullptr)
return 0;
const int left_sub_height = left->GetHeight();
const int right_sub_height = right->GetHeight();
return left_sub_height - right_sub_height;
}
Node *RotateLL()
{
Node<Key, Val> *parNode = this;
Node<Key, Val> *curNode = parNode->left;
parNode->ChangeLeft(curNode->right);
curNode->ChangeRight(parNode);
return curNode;
}
Node *RotateRR()
{
Node<Key, Val> *parNode = this;
Node<Key, Val> *curNode = parNode->right;
parNode->ChangeRight(curNode->left);
curNode->ChangeLeft(parNode);
return curNode;
}
Node *RotateLR()
{
Node<Key, Val> *parNode = this;
Node<Key, Val> *curNode = parNode->left;
parNode->ChangeLeft(curNode->RotateRR());
return parNode->RotateLL();
}
Node *RotateRL()
{
Node<Key, Val> *parNode = this;
Node<Key, Val> *curNode = parNode->right;
parNode->ChangeRight(curNode->RotateLL());
return parNode->RotateRR();
}
void Print()
{
if(this == nullptr)
return;
if(left != nullptr)
cout << "left: " << left->GetKey() << ' ';
if(right != nullptr)
cout << "right: " << right->GetKey() << ' ';
cout << endl;
left->Print();
right->Print();
}
};
AVLTree.h
#pragma once
#include <iostream>
#include <utility>
#include "Node.h"
using namespace std;
template <typename Key, typename Val>
class AVL
{
private:
Node<Key, Val> *root = nullptr;
public:
~AVL() = default;
void insert(const pair<Key, Val> &p)
{
insert(root, p);
}
void insert(Node<Key, Val> *&pNode, const pair<Key, Val> &p)
{
if(pNode == nullptr)
{
pNode = new Node<Key, Val>;
pNode->node = p;
}
else if(p.first < pNode->GetKey())
{
insert(pNode->left, p);
pNode = Rebalance(pNode);
}
else if(p.first > pNode->GetKey())
{
insert(pNode->right, p);
pNode = Rebalance(pNode);
}
else
{
cout << "ERROR: Overlap" << endl;
return;
}
return;
}
void erase(const Key &target)
{
Node<Key, Val> *parNode = nullptr;
Node<Key, Val> *curNode = root;
while (curNode != nullptr && curNode->GetKey() != target)
{
parNode = curNode;
if(target < curNode->GetKey())
curNode = curNode->left;
else
curNode = curNode->right;
}
if(curNode == nullptr)
{
cout << "ERROR: Memory Does Not Exist" << endl;
exit(-1);
}
Node<Key, Val> *delNode = curNode;
if(delNode->left == nullptr && delNode->right == nullptr) // 단말 노드인 경우
{
if(parNode->left == delNode)
parNode->RemoveLeft();
else
parNode->RemoveRight();
}
else if(delNode->left == nullptr || delNode->right == nullptr) // 자식 노드가 하나인 경우
{
Node<Key, Val> *childNode = nullptr;
if(delNode->left != nullptr)
childNode = delNode->left;
else
childNode = delNode->right;
if(parNode->left == delNode)
parNode->ChangeLeft(childNode);
else
parNode->ChangeRight(childNode);
}
else // 자식 노드가 2개인 경우
{
Node<Key, Val> *rNode = delNode->right;
Node<Key, Val> *rParNode = delNode;
while (rNode->left != nullptr)
{
rParNode = rNode;
rNode = rNode->left;
}
delNode->node = rNode->node;
if(rParNode->left == rNode)
rParNode->ChangeLeft(rNode->left);
else
rParNode->ChangeRight(rNode->right);
delNode = rNode;
}
Node<Key, Val> *vRoot = new Node<Key, Val>;
vRoot->SetRight(root);
if(vRoot->right != root)
root = vRoot->right;
delete vRoot;
delete delNode;
root = Rebalance(root);
}
const Val *find(const Key &key) const
{
Node<Key, Val> *curNode = root;
Key curKey;
while (curNode != nullptr)
{
curKey = curNode->GetKey();
if(key == curKey)
return &curNode->node.second;
else if(key < curKey)
curNode = curNode->left;
else
curNode = curNode->right;
}
return nullptr;
}
const Val &operator[](const Key &&key) const { return *find(key); }
const Val &operator[](const Key &key) const { return *find(key); }
void show() const
{
if(root != nullptr)
cout << "root: " << root->GetKey() << endl;
root->Print();
}
private:
Node<Key, Val> *Rebalance(Node<Key, Val> *pNode)
{
const int dif = pNode->GetHighDiff();
if(dif > 1)
{
if(pNode->left->GetHighDiff() > 0)
pNode = pNode->RotateLL();
else
pNode = pNode->RotateLR();
}
else if(dif < -1)
{
if(pNode->right->GetHighDiff() < 0)
pNode = pNode->RotateRR();
else
pNode = pNode->RotateRL();
}
return pNode;
}
};
AVLTree.cpp
#include <iostream>
#include <utility>
#include "AVLTree.h"
using namespace std;
int main()
{
AVL<int, char> a;
for (int i = 0; i < 9; i++)
{
a.insert(make_pair(i + 1, 'a' + i));
}
a.show();
return 0;
}
'알고리즘 Algorithm > 자료구조 Data structure' 카테고리의 다른 글
테이블 (Table) : 체이닝 (Chaining) (0) | 2021.03.25 |
---|---|
테이블 (Table) : Key & Value, Hash (해쉬), Collision (충돌) (0) | 2021.03.24 |
탐색 (Search) : 이진 탐색 트리 (Binary Search Tree) (0) | 2021.03.11 |
탐색 (Search) : 보간 탐색 (Interpolation Search) (0) | 2021.03.09 |
정렬 (sort) : 기수 정렬 (Radix sort) (0) | 2021.03.02 |