본문 바로가기

알고리즘 Algorithm/자료구조 Data structure

탐색 (Search) : AVL 트리 (AVL Tree)

원리


이진 트리 탐색 (Binary Search Tree)는 최악의 경우 \(O(n)\)이라는 시간 복잡도를 가집니다. 이를 대비하여 트리를 꾸준하게 리밸런싱(Rebalancing)을 해줘야하고, 그로 인해 나온 트리 중 하나가 AVL Tree입니다.

  • 균형 인수 : 왼쪽 서브 트리의 높이 - 오른쪽 서브 트리의 높이

리밸런싱을 진행할 시 균형인수의 절댓값이 2이상이여야 합니다.

  • LL 회전


    • 왼쪽에 노드가 2개 이상, 즉 균형 인수가 2이상인 경우 LL회전을 통해 균형을 잡습니다.
  • RR 회전


    • 오른쪽에 노드가 2개 이상, 즉 균형 인수가 -2이하인 경우 RR회전을 통해 균형을 잡습니다.
  • LR 회전


    • 왼쪽에 자식 노드가 존재하고 그 자식 노드의 오른쪽에 자식 노드가 존재하는 서브트리, 즉 균형 인수가 왼쪽 서브트리를 기준하여 계산할 때 마이너스일 시 RR회전 후 LL회전을 통해 균형을 잡습니다.
  • RL 회전


    • 오른쪽에 자식 노드가 존재하고 그 자식 노드의 왼쪽에 자식 노드가 존재하는 서브트리, 즉 균형 인수가 오른쪽 서브트리를 기준하여 계산할 때 플러스일 시 LL회전 후 RR회전을 통해 균형을 잡습니다.
  • insert (삽입)

    • root노드의 key값과 비교하여, 서브트리로 이동합니다. 서브트리로 이동 시 각 노드마다 리밸런싱을 진행하여 알맞는 노드의 자리를 찾는 작업을 진행합니다.
  • erase (삭제)

    • 자식 노드가 없을 때

      • target의 부모 노드의 자식 노드를 null로 바꾼 뒤 target을 delete합니다.
    • 자식 노드가 하나일 때

      • target의 부모 노드와 자식 노드를 이어 주고 target을 delete합니다.
    • 자식 노드가 2개일 때

      • target의 오른쪽 서브 트리 중 가장 작은 값을 부모 노드와 연결 후 target을 delete합니다.
  • find (탐색)

    • root노드부터 key됴값을 비교하면서 서브트리로 이동합니다. target값과 같을 시 target의 value의 주소를 반환합니다.

소스 코드 (구현)


Node.h

#pragma once
#include <utility>

using namespace std;

template <typename Key, typename Val>
class Node
{
    template <typename T, typename U>
    friend class AVL;

private:
    pair<Key, Val> node;
    Node<Key, Val> *left = nullptr;
    Node<Key, Val> *right = nullptr;

public:
    virtual ~Node() = default;
    Key GetKey() const { return node.first; }
    Val GetData() const { return node.second; }

private:
    void SetLeft(Node<Key, Val> *const sub)
    {
        if (left != nullptr)
            delete this->left;

        left = sub;
    }

    void SetRight(Node<Key, Val> *const sub)
    {
        if (right != nullptr)
            delete this->right;

        right = sub;
    }

    void ChangeLeft(Node<Key, Val> *const sub) { left = sub; }
    void ChangeRight(Node<Key, Val> *const sub) { right = sub; }

    Node *RemoveLeft()
    {
        Node<Key, Val> *delNode = nullptr;

        if(this != nullptr)
        {
            delNode = left;
            left = nullptr;
        }

        return delNode;
    }

    Node *RemoveRight()
    {
        Node<Key, Val> *delNode = nullptr;

        if(this != nullptr)
        {
            delNode = right;
            right = nullptr;
        }

        return delNode;
    }

    const int GetHeight() const
    {
        if(this == nullptr)
            return 0;

        const int left_height = left->GetHeight();
        const int right_height = right->GetHeight();

        if(left_height > right_height)
            return left_height + 1;
        else
            return right_height + 1;
    }

    const int GetHighDiff() const
    {
        if(this == nullptr)
            return 0;

        const int left_sub_height = left->GetHeight();
        const int right_sub_height = right->GetHeight();

        return left_sub_height - right_sub_height;
    }

    Node *RotateLL()
    {
        Node<Key, Val> *parNode = this;
        Node<Key, Val> *curNode = parNode->left;

        parNode->ChangeLeft(curNode->right);
        curNode->ChangeRight(parNode);

        return curNode;
    }

    Node *RotateRR()
    {
        Node<Key, Val> *parNode = this;
        Node<Key, Val> *curNode = parNode->right;

        parNode->ChangeRight(curNode->left);
        curNode->ChangeLeft(parNode);

        return curNode;
    }

    Node *RotateLR()
    {
        Node<Key, Val> *parNode = this;
        Node<Key, Val> *curNode = parNode->left;

        parNode->ChangeLeft(curNode->RotateRR());

        return parNode->RotateLL();
    }

    Node *RotateRL()
    {
        Node<Key, Val> *parNode = this;
        Node<Key, Val> *curNode = parNode->right;

        parNode->ChangeRight(curNode->RotateLL());

        return parNode->RotateRR();
    }

    void Print()
    {
        if(this == nullptr)
            return;

        if(left != nullptr)
            cout << "left: " << left->GetKey() << ' ';

        if(right != nullptr)
            cout << "right: " << right->GetKey() << ' ';

        cout << endl;
        left->Print();
        right->Print();
    }
};

AVLTree.h

#pragma once
#include <iostream>
#include <utility>
#include "Node.h"

using namespace std;

template <typename Key, typename Val>
class AVL
{
private:
    Node<Key, Val> *root = nullptr;

public:
    ~AVL() = default;
    void insert(const pair<Key, Val> &p)
    {
        insert(root, p);
    }

    void insert(Node<Key, Val> *&pNode, const pair<Key, Val> &p)
    {
        if(pNode == nullptr)
        {
            pNode = new Node<Key, Val>;
            pNode->node = p;
        }

        else if(p.first < pNode->GetKey())
        {
            insert(pNode->left, p);
            pNode = Rebalance(pNode);
        }

        else if(p.first > pNode->GetKey())
        {
            insert(pNode->right, p);
            pNode = Rebalance(pNode);
        }

        else
        {
            cout << "ERROR: Overlap" << endl;
            return;
        }

        return;
    }

    void erase(const Key &target)
    {
        Node<Key, Val> *parNode = nullptr;
        Node<Key, Val> *curNode = root;

        while (curNode != nullptr && curNode->GetKey() != target)
        {
            parNode = curNode;

            if(target < curNode->GetKey())
                curNode = curNode->left;
            else
                curNode = curNode->right;
        }

        if(curNode == nullptr)
        {
            cout << "ERROR: Memory Does Not Exist" << endl;
            exit(-1);
        }

        Node<Key, Val> *delNode = curNode;

        if(delNode->left == nullptr && delNode->right == nullptr) // 단말 노드인 경우
        {
            if(parNode->left == delNode)
                parNode->RemoveLeft();
            else
                parNode->RemoveRight();
        }

        else if(delNode->left == nullptr || delNode->right == nullptr) // 자식 노드가 하나인 경우
        {
            Node<Key, Val> *childNode = nullptr;

            if(delNode->left != nullptr)
                childNode = delNode->left;
            else
                childNode = delNode->right;

            if(parNode->left == delNode)
                parNode->ChangeLeft(childNode);
            else
                parNode->ChangeRight(childNode);
        }

        else // 자식 노드가 2개인 경우
        {
            Node<Key, Val> *rNode = delNode->right;
            Node<Key, Val> *rParNode = delNode;

            while (rNode->left != nullptr)
            {
                rParNode = rNode;
                rNode = rNode->left;
            }

            delNode->node = rNode->node;

            if(rParNode->left == rNode)
                rParNode->ChangeLeft(rNode->left);
            else
                rParNode->ChangeRight(rNode->right);

            delNode = rNode;
        }

        Node<Key, Val> *vRoot = new Node<Key, Val>;
        vRoot->SetRight(root);

        if(vRoot->right != root)
            root = vRoot->right;

        delete vRoot;
        delete delNode;

        root = Rebalance(root);
    }

    const Val *find(const Key &key) const
    {
        Node<Key, Val> *curNode = root;
        Key curKey;

        while (curNode != nullptr)
        {
            curKey = curNode->GetKey();

            if(key == curKey)
                return &curNode->node.second;
            else if(key < curKey)
                curNode = curNode->left;
            else
                curNode = curNode->right;
        }

        return nullptr;
    }

    const Val &operator[](const Key &&key) const { return *find(key); }
    const Val &operator[](const Key &key) const { return *find(key); }

    void show() const
    {
        if(root != nullptr)
            cout << "root: " << root->GetKey() << endl;

        root->Print();
    }

private:
    Node<Key, Val> *Rebalance(Node<Key, Val> *pNode)
    {
        const int dif = pNode->GetHighDiff();

        if(dif > 1)
        {
            if(pNode->left->GetHighDiff() > 0)
                pNode = pNode->RotateLL();
            else
                pNode = pNode->RotateLR();
        }

        else if(dif < -1)
        {
            if(pNode->right->GetHighDiff() < 0)
                pNode = pNode->RotateRR();
            else
                pNode = pNode->RotateRL();
        }

        return pNode;
    }
};

AVLTree.cpp

#include <iostream>
#include <utility>
#include "AVLTree.h"

using namespace std;

int main()
{
    AVL<int, char> a;

    for (int i = 0; i < 9; i++)
    {
        a.insert(make_pair(i + 1, 'a' + i));
    }

    a.show();

    return 0;
}