/// Randomized Treap (Cartesian Tree)
///   Store x (key), simulate y (priority): no priority is stored; instead
///     tMerge chooses the root at random in proportion to subtree sizes
///   Store value
///   Maintain size of subtree
///   total: sum of values in subtree
///   largest: maximum of values in subtree
/// Base Functions
///   tSplit: split into keys < x and keys >= x
///   tMerge: merge trees l and r if {keys of l} <= {keys of r}
/// Map Interface (Multimap)
///   tInsert: add (x, value) pair into map
///   tFind: return value for given key x (-1 if the key is absent)
///   tRemove: erase one key x from map
/// Debugging
///   tOutput: print bracket structure of tree
///   tOutput2: print vertical representation
/// Example Usage
///   Add 10^7 elements
///   Cut middle part, check total and largest
///   Find 10^7 elements
///   (disabled) Remove 10^7 elements

#include <algorithm>
#include <climits>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <random>
#include <utility>

// Global RNG driving the randomized merge. It is seeded from wall-clock time,
// so tree shapes differ between runs (results do not).
auto gen = std::mt19937 (static_cast <std::mt19937::result_type> (std::time (nullptr)));
using uniform = std::uniform_int_distribution <int>;

struct Node;
using PNode = Node *;
using Num = long long;

int getSize (PNode t);
Num getTotal (PNode t);
Num getLargest (PNode t);

/// One treap node. A node owns both of its subtrees: destroying a node
/// destroys everything below it.
struct Node {
    int x;          // key
    Num value;      // value stored under the key
    Num total;      // sum of `value` over this subtree
    Num largest;    // maximum of `value` over this subtree
    int size;       // number of nodes in this subtree
    PNode left;     // owned child (may be nullptr)
    PNode right;    // owned child (may be nullptr)

    /// Recompute the cached aggregates (size, total, largest) from this
    /// node's value and its children. Call after changing left or right.
    void recalc () {
        size = 1 + getSize (left) + getSize (right);
        total = value + getTotal (left) + getTotal (right);
        largest = std::max (value, std::max (getLargest (left), getLargest (right)));
    }

    /// Create a leaf holding (x_, value_).
    Node (int x_, Num value_)
        : x (x_), value (value_), total (value_), largest (value_),
          size (1), left (nullptr), right (nullptr) {}

    // The destructor deletes the children, so a shallow copy would lead to a
    // double delete; forbid copying.
    Node (Node const &) = delete;
    Node & operator = (Node const &) = delete;

    // Recursion depth equals the subtree height. `delete nullptr` is a no-op.
    ~Node () {
        delete left;
        delete right;
    }
};

/// Subtree size; 0 for an empty tree.
int getSize (PNode t) {return (t == nullptr) ? 0 : t -> size;}
/// Subtree sum of values; 0 for an empty tree.
Num getTotal (PNode t) {return (t == nullptr) ? Num (0) : t -> total;}
/// Subtree maximum of values; LLONG_MIN (identity for max) for an empty tree.
Num getLargest (PNode t) {return (t == nullptr) ? Num (LLONG_MIN) : t -> largest;}

/// Split tree t into (keys < x, keys >= x). Reuses t's nodes; t must not be
/// used afterwards except through the returned roots.
std::pair <PNode, PNode> tSplit (PNode t, int x) {
    if (t == nullptr) return {nullptr, nullptr};
    if (t -> x >= x) {
        // t and its right subtree belong to the ">= x" side.
        auto p = tSplit (t -> left, x);
        t -> left = p.second;
        t -> recalc ();
        return {p.first, t};
    } else {
        // t and its left subtree belong to the "< x" side.
        auto p = tSplit (t -> right, x);
        t -> right = p.first;
        t -> recalc ();
        return {t, p.second};
    }
}

/// Merge trees l and r into one tree and return its root.
/// Precondition: every key in l <= every key in r.
/// Instead of comparing stored priorities, r becomes the root with probability
/// |r| / (|l| + |r|), the same distribution a random-priority treap yields.
PNode tMerge (PNode l, PNode r) {
    if (l == nullptr) return r;
    if (r == nullptr) return l;
    if (uniform (0, l -> size + r -> size - 1) (gen) >= l -> size) {
        r -> left = tMerge (l, r -> left);
        r -> recalc ();
        return r;
    } else {
        l -> right = tMerge (l -> right, r);
        l -> recalc ();
        return l;
    }
}

/// Insert (x, value) and return the new root. Duplicate keys are allowed; the
/// new node is placed before existing nodes with the same key.
PNode tInsert (PNode t, int x, Num value) {
    auto v = new Node (x, value);
    auto p = tSplit (t, x);
    auto half = tMerge (p.first, v);
    return tMerge (half, p.second);
}

/// Print the tree in bracket form "(left x:value|size right)"; when doEndl is
/// set, terminate the output with ".\n".
void tOutput (PNode t, bool doEndl = true) {
    if (t != nullptr) {
        std::cout << '(';
        tOutput (t -> left, false);
        std::cout << t -> x << ':' << t -> value << '|' << t -> size;
        tOutput (t -> right, false);
        std::cout << ')';
    }
    if (doEndl) std::cout << '.' << '\n';
}

/// Print one node per line in in-order, indented 4 columns per depth level.
void tOutput2 (PNode t, int depth = 1) {
    if (t != nullptr) {
        tOutput2 (t -> left, depth + 1);
        std::cout << std::setw (depth * 4) << ' ' << t -> x
                  << " value=" << t -> value
                  << " total=" << t -> total
                  << " largest=" << t -> largest << '\n';
        tOutput2 (t -> right, depth + 1);
    }
}

/// Return the value of a node with key x, or -1 if x is absent.
/// NOTE: -1 is ambiguous if -1 can itself be stored as a value.
Num tFind (PNode t, int x) {
    if (t == nullptr) return Num (-1);
    if (t -> x == x) return t -> value;
    if (t -> x < x) return tFind (t -> right, x);
    else return tFind (t -> left, x);
}

/// Erase one node with key x (the first one met on the search path) and
/// return the new root; the tree is unchanged if x is absent.
PNode tRemove (PNode t, int x) {
    if (t == nullptr) return t;
    if (t -> x == x) {
        auto res = tMerge (t -> left, t -> right);
        // Detach the children so ~Node does not delete the merged subtrees.
        t -> left = nullptr;
        t -> right = nullptr;
        delete t;
        return res;
    }
    if (t -> x < x) {
        t -> right = tRemove (t -> right, x);
        t -> recalc ();
        return t;
    } else {
        t -> left = tRemove (t -> left, x);
        t -> recalc ();
        return t;
    }
}

int main () {
    int const n = int (1E7);
    PNode root = nullptr;

    // Insert keys 2, 4, ..., 2n with values x^2 mod 17.
    for (int x = 1; x <= n; x++) {
        root = tInsert (root, x * 2, (x * 1LL * x) % 17);
        if (x <= 10) {tOutput2 (root); std::cout << '\n';}
    }

    // Cut out keys in [3333333, 6666666), report their sum/max, glue back.
    {
        auto p = tSplit (root, 3333333);
        auto q = tSplit (p.second, 6666666);
        // Null-safe accessors: the middle part may be empty for small n.
        std::cout << getTotal (q.first) << " " << getLargest (q.first) << '\n';
        p.second = tMerge (q.first, q.second);
        root = tMerge (p.first, p.second);
    }

    // Look up keys 1..n; only even keys were inserted, so odd x yields -1.
    for (int x = 1; x <= n; x++) {
        auto res = tFind (root, x);
        if (x <= 10) {std::cout << x << ' ' << res << '\n';}
    }

    if (false)
    for (int x = 1; x <= n; x++) {
        root = tRemove (root, x * 2);
        if (n - x <= 10) tOutput (root);
    }

    // Release the whole tree (Node's destructor frees the subtrees).
    delete root;
    return 0;
}