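// Note: a red-black tree is equivalent to a 2-3-4 tree (a B-tree of order 4): each black node together with its red children forms one 2-, 3- or 4-node. The correspondence is not one-to-one, because a 3-node can be encoded with its red child on either side.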

  /**
   * Ordered key -> value map implemented as a red-black tree (CLRS-style).
   *
   * A single black sentinel node `nil` stands in for every leaf and for the
   * root's parent, so the algorithms never test for null child pointers.
   *
   * Requirements:
   *  - Key: default-constructible (the sentinel holds a Key{}), copy-constructible,
   *    and ordered by `operator<`, which must be a strict weak ordering. Two keys
   *    are treated as the same key when neither is less than the other, so Key
   *    does not need `operator==` / `operator!=`.
   *  - Value: default- and copy-constructible, copy-assignable.
   *
   * Inserting a key that is already present overwrites its value. Nodes are never
   * moved or copied once created, so a pointer returned by get() stays valid until
   * that key is erased or the tree is destroyed. The tree is not copyable.
   */
  template<typename Key, typename Value = Key>
  class RBTree {
  public:
      RBTree(const RBTree&) = delete;
      RBTree& operator=(const RBTree&) = delete;

      /// Creates an empty tree (allocates only the sentinel).
      RBTree() :
          root(nullptr), nil(nullptr), cnt(0)
      {
          nil = new Node();
          nil->c = BLACK;
          nil->l = nil->r = nil->p = nil;
          root = nil;
      }

      /// Frees every node and then the sentinel.
      ~RBTree()
      {
          clear_node(root);
          delete nil;
      }

      /**
       * Inserts `k` with value `v`, or overwrites the stored value if `k` is
       * already present (size() is unchanged in that case).
       * A node is allocated only when the key is new.
       */
      void insert(const Key& k, const Value& v = Value())
      {
          Node* parent = nil;
          for (Node* cur = root; cur != nil;) {
              parent = cur;
              if (k < cur->key) {
                  cur = cur->l;
              }
              else if (cur->key < k) {
                  cur = cur->r;
              }
              else {
                  cur->val = v;  // key already present: update in place
                  return;
              }
          }
          Node* node = new Node(k, v);  // new nodes start red
          node->l = node->r = nil;
          node->p = parent;
          if (parent == nil) root = node;
          else if (k < parent->key) parent->l = node;
          else parent->r = node;
          ++cnt;
          insert_fixup(node);
      }

      /**
       * Removes `k` if present.
       * @return true if a node was removed, false if `k` was not in the tree.
       * Pointers from get() for other keys stay valid: when the erased node has
       * two children, its successor node is relinked into its place rather than
       * having its key/value copied.
       */
      bool erase(const Key& k)
      {
          Node* node = find_node(k);
          if (node == nil) return false;
          Node* repl = node;                  // node actually unlinked / moved
          Color repl_original_c = repl->c;    // colour that disappears from its position
          Node* cur;                          // node that takes repl's old position
          if (node->l == nil) {
              cur = node->r;
              transplant(node, node->r);
          }
          else if (node->r == nil) {
              cur = node->l;
              transplant(node, node->l);
          }
          else {
              // Two children: splice in the in-order successor.
              repl = minimum(node->r);
              repl_original_c = repl->c;
              cur = repl->r;
              if (repl->p == node) {
                  cur->p = repl;  // needed even when cur is nil: erase_fixup reads it
              }
              else {
                  transplant(repl, repl->r);
                  repl->r = node->r;
                  repl->r->p = repl;
              }
              transplant(node, repl);
              repl->l = node->l;
              repl->l->p = repl;
              repl->c = node->c;
          }
          delete node;
          --cnt;
          // Removing a black node shortens black heights on one path; repair it.
          if (repl_original_c == BLACK) erase_fixup(cur);
          return true;
      }

      bool empty() const { return root == nil; }
      size_t size() const { return cnt; }
      bool contains(const Key& k) const { return find_node(k) != nil; }

      /// Returns a pointer to the value stored for `k`, or nullptr if absent.
      Value* get(const Key& k)
      {
          Node* n = find_node(k);
          return n != nil ? &n->val : nullptr;
      }

      /// Read-only lookup; returns nullptr if `k` is absent.
      const Value* get(const Key& k) const
      {
          Node* n = find_node(k);
          return n != nil ? &n->val : nullptr;
      }

  private:
      enum Color { RED, BLACK };
      struct Node {
          Key key;
          Value val;
          Node *l, *r, *p;
          Color c;
          Node(const Node&) = delete;
          Node& operator=(const Node&) = delete;
          Node(const Key& k = {}, const Value& v = Value {}, Color col = RED) :
              key(k), val(v), l(nullptr), r(nullptr), p(nullptr), c(col) {}
      };
      Node *root, *nil;
      size_t cnt = 0;

      /// Left rotation around `parent`: its right child takes its place.
      void rotate_left(Node* parent)
      {
          Node* succ = parent->r;
          parent->r = succ->l;
          if (succ->l != nil) succ->l->p = parent;  // never overwrite nil->p here
          succ->p = parent->p;
          if (parent->p == nil) root = succ;
          else if (parent == parent->p->l) parent->p->l = succ;
          else parent->p->r = succ;
          succ->l = parent;
          parent->p = succ;
      }

      /// Right rotation around `parent`: its left child takes its place.
      void rotate_right(Node* parent)
      {
          Node* succ = parent->l;
          parent->l = succ->r;
          if (succ->r != nil) succ->r->p = parent;  // never overwrite nil->p here
          succ->p = parent->p;
          if (parent->p == nil) root = succ;
          else if (parent == parent->p->r) parent->p->r = succ;
          else parent->p->l = succ;
          succ->r = parent;
          parent->p = succ;
      }

      /**
       * Restores the red-black properties after `cur` was linked in as a red leaf.
       * Diagram notation: [x] black, (x) red, *x = node `cur` points at.
       * g = grandparent, p = parent, u = uncle.
       */
      void insert_fixup(Node* cur)
      {
          // The loop ends at the root because root->p is the black sentinel.
          while (cur->p->c == RED) {
              if (cur->p == cur->p->p->l) {
                  Node* uncle = cur->p->p->r;
                  if (uncle->c == RED) {
                      /*       [g]              *(g)
                       *       / \               / \
                       *     (p) (u)    ->     [p] [u]
                       *     /                 /
                       *  *(x)               (x)
                       */
                      cur->p->c = uncle->c = BLACK;
                      cur->p->p->c = RED;  // flip colours of parent, uncle and grandparent
                      cur = cur->p->p;     // continue fixing up from the grandparent
                  }
                  else {
                      /*    [g]          [g]          [x]
                       *    /            /            / \
                       *  (p)    ->    (x)    ->    (p) (g)
                       *    \          /
                       *    *(x)     *(p)
                       */
                      if (cur == cur->p->r) {
                          cur = cur->p;
                          rotate_left(cur);
                      }
                      cur->p->c = BLACK;
                      cur->p->p->c = RED;
                      rotate_right(cur->p->p);
                  }
              }
              else {
                  // Mirror image of the branch above.
                  Node* uncle = cur->p->p->l;
                  if (uncle->c == RED) {
                      cur->p->c = uncle->c = BLACK;
                      cur->p->p->c = RED;
                      cur = cur->p->p;
                  }
                  else {
                      if (cur == cur->p->l) {
                          cur = cur->p;
                          rotate_right(cur);
                      }
                      cur->p->c = BLACK;
                      cur->p->p->c = RED;
                      rotate_left(cur->p->p);
                  }
              }
          }
          root->c = BLACK;
      }

      /// Returns the node holding a key equivalent to `k`, or nil if none.
      Node* find_node(const Key& k) const
      {
          Node* cur = root;
          while (cur != nil) {
              if (k < cur->key) cur = cur->l;
              else if (cur->key < k) cur = cur->r;
              else break;  // neither is less: equivalent key found
          }
          return cur;
      }

      /**
       * Replaces the subtree rooted at `u` with the subtree rooted at `v` in u's
       * parent. Sets v->p even when v is nil, because erase_fixup relies on it.
       */
      void transplant(Node* u, Node* v)
      {
          if (u->p == nil) root = v;
          else if (u == u->p->l) u->p->l = v;
          else u->p->r = v;
          v->p = u->p;
      }

      /// Leftmost (smallest-key) node of the non-empty subtree rooted at `cur`.
      Node* minimum(Node* cur) const
      {
          while (cur->l != nil)
              cur = cur->l;
          return cur;
      }

      /**
       * Restores the red-black properties after a black node was removed; `cur`
       * carries an "extra black" (and may be the sentinel).
       * Diagram notation: [x] black, (x) red, plain = either colour,
       * *x = node `cur` points at. p = parent, s = sibling, a/b = near/far nephew.
       */
      void erase_fixup(Node* cur)
      {
          while (cur != root && cur->c == BLACK) {
              if (cur == cur->p->l) {
                  Node* sib = cur->p->r;
                  /* Case 1: red sibling -> rotate so that the sibling becomes black.
                   *       [p]                 [s]
                   *       / \                 / \
                   *    *[x] (s)      ->     (p) [b]
                   *         / \             / \
                   *       [a] [b]        *[x] [a]
                   */
                  if (sib->c == RED) {
                      sib->c = BLACK;
                      cur->p->c = RED;
                      rotate_left(cur->p);
                      sib = cur->p->r;
                  }
                  /* Case 2: black sibling with two black children -> push the extra
                   * black up to the parent.
                   *        p                  *p
                   *       / \                 / \
                   *    *[x] [s]      ->     [x] (s)
                   *         / \                 / \
                   *       [a] [b]             [a] [b]
                   */
                  if (sib->l->c == BLACK && sib->r->c == BLACK) {
                      sib->c = RED;
                      cur = cur->p;
                  }
                  else {
                      /* Case 3: far nephew black, near nephew red -> turn into case 4.
                       *        p                   p
                       *       / \                 / \
                       *    *[x] [s]      ->    *[x] [a]
                       *         / \                   \
                       *       (a) [b]                 (s)
                       *                                 \
                       *                                 [b]
                       */
                      if (sib->r->c == BLACK) {
                          sib->l->c = BLACK;
                          sib->c = RED;
                          rotate_right(sib);
                          sib = cur->p->r;
                      }
                      /* Case 4: far nephew red -> rotate; s takes p's colour. Done.
                       *        p                   s
                       *       / \                 / \
                       *    *[x] [s]      ->     [p] [b]
                       *         / \             / \
                       *        a  (b)        [x]   a
                       */
                      sib->c = cur->p->c;
                      cur->p->c = BLACK;
                      sib->r->c = BLACK;
                      rotate_left(cur->p);
                      cur = root;  // terminates the loop
                  }
              }
              else {
                  // Mirror image of the branch above.
                  Node* sib = cur->p->l;
                  if (sib->c == RED) {
                      sib->c = BLACK;
                      cur->p->c = RED;
                      rotate_right(cur->p);
                      sib = cur->p->l;
                  }
                  if (sib->r->c == BLACK && sib->l->c == BLACK) {
                      sib->c = RED;
                      cur = cur->p;
                  }
                  else {
                      if (sib->l->c == BLACK) {
                          sib->r->c = BLACK;
                          sib->c = RED;
                          rotate_left(sib);
                          sib = cur->p->l;
                      }
                      sib->c = cur->p->c;
                      cur->p->c = BLACK;
                      sib->l->c = BLACK;
                      rotate_right(cur->p);
                      cur = root;
                  }
              }
          }
          cur->c = BLACK;
      }

      /// Post-order deletes the subtree rooted at `cur` (recursion depth = height).
      void clear_node(Node* cur)
      {
          if (cur == nil) return;
          clear_node(cur->l);
          clear_node(cur->r);
          delete cur;
      }
  };