经典的树套树（线段树套 Splay），顺便测一下自己的模板。奇怪的是，使用通用模板反而比专门优化的直接寻址版本要快。
这是最终版本：
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

#define MAXN 50005

// Splay tree used as an ordered multiset that can count keys below a bound.
// Every Splay<T> instance shares one static node pool: Node::operator new hands
// out slots from _nodes (recycling slots that operator delete pushed onto _st),
// and mem_init() resets the pool. The pool size is not bounds-checked.
template <class T>
struct Splay {
    struct Node {
        Node *_fa, *_ch[2];  // parent; left (_ch[0]) and right (_ch[1]) children
        T _val;              // key stored in this node
        int _size;           // number of nodes in the subtree rooted here

        Node() {}
        explicit Node(const T& x) { set_value(x); }

        // Take a slot from the shared pool, preferring previously freed ones.
        void* operator new(size_t) {
            if (_stop != 0) return _st[--_stop];
            return _nodes + _mtot++;
        }
        // Give the slot back to the pool's free stack.
        void operator delete(void* p, size_t) { _st[_stop++] = static_cast<Node*>(p); }

        // Reset this node to a detached single-node tree holding x.
        void set_value(const T& x) {
            _val = x;
            _size = 1;
            _ch[0] = _ch[1] = 0;
            _fa = _fa_root;
        }

        // c == 0: leftmost node of this subtree; c == 1: rightmost node.
        Node* get_end(int c) {
            Node* x = this;
            while (x->_ch[c]) x = x->_ch[c];
            return x;
        }

        // In-order neighbour: c == 0 predecessor, c == 1 successor; 0 if none.
        Node* get_near(int c) {
            if (_ch[c]) return _ch[c]->get_end(1 - c);
            Node* x = this;
            while (x->_fa && x->_fa->_ch[c] == x) x = x->_fa;
            return x->_fa;
        }
        Node* get_succ() { return get_near(1); }
        Node* get_prev() { return get_near(0); }

        // Push lazy tags down to the children (this tree carries no tags).
        void push_down() {}

        // Recompute _size from the children.
        void update() { _size = 1 + size_of(_ch[0]) + size_of(_ch[1]); }

        // Subtree size of x, treating a null pointer as an empty tree.
        // This replaces the former `if (this)` test inside get_size(): calling a
        // member function through a null pointer is undefined behaviour, and
        // GCC folds `this != 0` to true, so that null check never actually ran.
        static int size_of(const Node* x) { return x ? x->_size : 0; }

        // Subtree size of this node; only call it on a non-null node.
        int get_size() const { return _size; }
    };

    Node* _root;  // root of this tree, 0 when empty

    static Node* const _fa_root;          // parent pointer stored in the root (null)
    static Node _nodes[(MAXN + 1) * 18];  // shared node pool
    static Node* _st[(MAXN + 1) * 18];    // free stack of recycled pool slots
    static int _mtot, _stop;              // pool slots handed out / free-stack depth

    // Rotate x above its parent y. c == 0 is a left rotation (x was y's right
    // child), c == 1 a right rotation (x was y's left child). Only y's size is
    // refreshed here; splay() refreshes x once it stops moving.
    void rotate(Node* x, int c) {
        Node* y = x->_fa;
        y->push_down();  // y is above x, so push y's tags before x's
        x->push_down();
        y->_ch[1 - c] = x->_ch[c];
        if (x->_ch[c]) x->_ch[c]->_fa = y;
        x->_fa = y->_fa;
        if (y->_fa != _fa_root) {
            if (y->_fa->_ch[0] == y) y->_fa->_ch[0] = x;
            else y->_fa->_ch[1] = x;
        } else {
            _root = x;
        }
        x->_ch[c] = y;
        y->_fa = x;
        y->update();
    }

    // Splay x upward until its parent is f (f == _fa_root makes x the root).
    void splay(Node* x, Node* f) {
        x->push_down();
        while (x->_fa != f) {
            if (x->_fa->_fa == f) {
                // zig: one rotation finishes the job
                if (x->_fa->_ch[0] == x) rotate(x, 1);
                else rotate(x, 0);
            } else {
                Node *y = x->_fa, *z = y->_fa;
                if (z->_ch[0] == y) {
                    if (y->_ch[0] == x) rotate(y, 1), rotate(x, 1);  // zig-zig
                    else rotate(x, 0), rotate(x, 1);                 // zig-zag
                } else {
                    if (y->_ch[1] == x) rotate(y, 0), rotate(x, 0);  // zig-zig
                    else rotate(x, 1), rotate(x, 0);                 // zig-zag
                }
            }
        }
        x->update();
    }

    // Concatenate subtrees r1 and r2, where every key of r1 precedes every key
    // of r2; returns the new subtree root (0 if both are empty). The rightmost
    // node of r1 is splayed up to just below r1's current parent.
    Node* join(Node* r1, Node* r2) {
        if (r1 == 0) return r2;
        if (r2 == 0) return r1;
        Node* right_end = r1->get_end(1);
        splay(right_end, r1->_fa);
        right_end->_ch[1] = r2;  // right_end is the maximum, so this slot is free
        r2->_fa = right_end;
        right_end->update();
        return right_end;
    }

    // Insert px (placed left of any equal keys) and splay the new node to the root.
    void insert(const T& px) {
        Node* nd = new Node(px);
        if (_root == 0) {
            _root = nd;
            return;
        }
        Node *x = _root, *y = _fa_root;
        while (x) {
            y = x;
            if (nd->_val <= x->_val) x = x->_ch[0];
            else x = x->_ch[1];
        }
        if (nd->_val <= y->_val) y->_ch[0] = nd;
        else y->_ch[1] = nd;
        nd->_fa = y;
        splay(nd, _fa_root);
    }

    // Remove one node whose key equals px; do nothing if px is absent.
    void erase(const T& px) {
        Node* nd = lower_bound(px);
        if (nd == 0 || nd->_val != px) return;
        Node* x = join(nd->_ch[0], nd->_ch[1]);
        if (nd->_fa == _fa_root) {
            _root = x;
        } else if (nd->_fa->_ch[0] == nd) {
            nd->_fa->_ch[0] = x;
        } else {
            nd->_fa->_ch[1] = x;
        }
        if (x) x->_fa = nd->_fa;
        // The parent and its ancestors now hold stale sizes; splaying the
        // parent to the root refreshes every node on that path.
        if (nd->_fa) splay(nd->_fa, _fa_root);
        delete nd;
    }

    // First node in key order with key >= px, or 0 if none. Does not splay.
    Node* lower_bound(const T& px) {
        Node *x = _root, *y = _fa_root;
        while (x) {
            y = x;
            if (px <= x->_val) x = x->_ch[0];
            else x = x->_ch[1];
        }
        if (y == _fa_root) return 0;
        if (px <= y->_val) return y;
        return y->get_succ();
    }

    // First node in key order with key > px, or 0 if none. Does not splay.
    Node* upper_bound(const T& px) {
        Node *x = _root, *y = _fa_root;
        while (x) {
            y = x;
            if (px < x->_val) x = x->_ch[0];
            else x = x->_ch[1];
        }
        if (y == _fa_root) return 0;
        if (px < y->_val) return y;
        return y->get_succ();
    }

    // Number of keys strictly less than val. Splays the first key >= val to the
    // root so the answer is the size of its left subtree; an empty tree yields 0.
    int query_less(const T& val) {
        Node* x = lower_bound(val);
        if (x == 0) return Node::size_of(_root);  // every key (if any) is < val
        splay(x, _fa_root);
        return Node::size_of(x->_ch[0]);  // null when x is the minimum
    }

    // Forget all nodes of this tree (their pool slots are reclaimed by mem_init()).
    void clear() { _root = 0; }

    // Reset the shared pool; every Splay<T> tree must be cleared afterwards.
    static void mem_init() { _mtot = _stop = 0; }
};

template <class T> typename Splay<T>::Node* const Splay<T>::_fa_root = 0;
template <class T> typename Splay<T>::Node Splay<T>::_nodes[(MAXN + 1) * 18];
template <class T> typename Splay<T>::Node* Splay<T>::_st[(MAXN + 1) * 18];
template <class T> int Splay<T>::_stop = 0;
template <class T> int Splay<T>::_mtot = 0;

int v[MAXN + 1];  // current array values, 1-indexed

// Segment-tree node covering positions [_a, _b]; _bst holds all their values.
struct Node {
    int _a, _b;
    Node *_left, *_right;
    Splay<int> _bst;
} nodes[MAXN * 2], *ptr;

// Build the segment tree over [a, b] in preorder starting at ++ptr, so the
// overall root is nodes + 1; every node's splay tree receives v[a..b].
void build_tree(int a, int b) {
    Node* root = ++ptr;
    root->_a = a;
    root->_b = b;
    root->_bst.clear();
    for (int i = a; i <= b; ++i) root->_bst.insert(v[i]);
    if (a == b) {
        root->_left = root->_right = 0;
        return;
    }
    int mid = (a + b) / 2;
    root->_left = ptr + 1;
    build_tree(a, mid);
    root->_right = ptr + 1;
    build_tree(mid + 1, b);
}

int la, lb, lk;  // arguments shared with the recursive query/modify helpers

// Point update v[la] = lk: every node on the root-to-leaf path replaces the old
// value with lk. The leaf, visited last, also writes lk back into v, so every
// ancestor still sees the old value when erasing it.
void modify(Node* root) {
    root->_bst.erase(v[la]);
    if (root->_a == root->_b) {
        root->_bst.insert(v[la] = lk);
        return;
    }
    root->_bst.insert(lk);
    int mid = (root->_a + root->_b) / 2;
    if (la <= mid) modify(root->_left);
    else modify(root->_right);
}

// Count values < lk among positions [la, lb].
int query(Node* root) {
    if (la <= root->_a && root->_b <= lb) return root->_bst.query_less(lk);
    int mid = (root->_a + root->_b) / 2;
    int res = 0;
    if (la <= mid) res = query(root->_left);
    if (mid < lb) res += query(root->_right);
    return res;
}

// k-th smallest value among v[a..b]: the largest l in [0, 1e9] with fewer than
// k values below l. The search bounds assume values lie in [0, 1e9].
inline int solve(int a, int b, int k) {
    int l = 0, r = 1000000001, mid;
    while (l + 1 != r) {
        mid = (l + r) / 2;
        la = a;
        lb = b;
        lk = mid;
        if (query(nodes + 1) < k) l = mid;
        else r = mid;
    }
    return l;
}

int main() {
#ifndef ONLINE_JUDGE
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
#endif
    int test;
    scanf("%d", &test);
    for (int cas = 1; cas <= test; ++cas) {
        Splay<int>::mem_init();
        int n, m;
        scanf("%d%d", &n, &m);
        for (int i = 1; i <= n; ++i) scanf("%d", v + i);
        ptr = nodes;
        build_tree(1, n);
        for (int i = 1; i <= m; ++i) {
            char c;
            int a, b, k;
            scanf(" %c", &c);
            if (c == 'Q') {
                scanf("%d%d%d", &a, &b, &k);
                printf("%d\n", solve(a, b, k));
            } else {
                scanf("%d%d", &a, &k);
                la = a;
                lk = k;
                modify(nodes + 1);
            }
        }
    }
    return 0;
}
第二个版本：把最终版本中的 query_less 换成下面的实现，查询时不做 splay 操作，而是从根向下直接统计。这个版本比上面的程序略慢一些：
int query_less(const T& val) { int res = 0; Node* x = _root; while (x) { if (val <= x->_val) { x = x->_ch[0]; } else { ++res; if (x->_ch[0]) res += x->_ch[0]->_size; x = x->_ch[1]; } } return res; }
原始版本：8 秒勉强水过（= =...），这让我不得不怀疑 gcc 对这种直接寻址（静态二维数组 _nodes[i][dep]）的处理是不是有些问题。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

#define MAXN 50005

// Splay tree whose nodes are never heap-allocated: the node for array position
// i in the segment-tree level at depth d is the fixed slot Splay::_nodes[i][d].
struct Splay {
    struct Node {
        Node *_fa, *_ch[2];  // parent; left (_ch[0]) and right (_ch[1]) children
        int _val, _size;     // key; number of nodes in this subtree

        // Recompute _size from the children.
        void update() {
            _size = 1;
            if (_ch[0]) _size += _ch[0]->_size;
            if (_ch[1]) _size += _ch[1]->_size;
        }
        // Reset this slot to a detached single-node tree holding x.
        void set_value(int x) {
            _size = 1;
            _val = x;
            _ch[0] = _ch[1] = 0;
            _fa = _fa_root;
        }
        // c == 0: leftmost node of this subtree; c == 1: rightmost node.
        Node* get_end(int c) {
            Node* x = this;
            while (x->_ch[c]) x = x->_ch[c];
            return x;
        }
    };

    Node* _root;  // root of this tree

    static Node* const _fa_root;       // parent pointer stored in the root (null)
    static Node _nodes[MAXN + 1][18];  // slot [i][d]: position i at depth d

    // Rotate x above its parent y. c == 0 is a left rotation (x was y's right
    // child), c == 1 a right rotation (x was y's left child). Only y's size is
    // refreshed here; splay() refreshes x once it stops moving.
    void rotate(Node* x, int c) {
        Node* y = x->_fa;
        y->_ch[1 - c] = x->_ch[c];
        if (x->_ch[c]) x->_ch[c]->_fa = y;
        x->_fa = y->_fa;
        if (y->_fa != _fa_root) {
            if (y->_fa->_ch[0] == y) y->_fa->_ch[0] = x;
            else y->_fa->_ch[1] = x;
        } else {
            _root = x;
        }
        x->_ch[c] = y;
        y->_fa = x;
        y->update();
    }

    // Splay x upward until its parent is f (f == _fa_root makes x the root).
    void splay(Node* x, Node* f) {
        while (x->_fa != f) {
            if (x->_fa->_fa == f) {
                // zig: one rotation finishes the job
                if (x->_fa->_ch[0] == x) rotate(x, 1);
                else rotate(x, 0);
            } else {
                Node *y = x->_fa, *z = y->_fa;
                if (z->_ch[0] == y) {
                    if (y->_ch[0] == x) rotate(y, 1), rotate(x, 1);  // zig-zig
                    else rotate(x, 0), rotate(x, 1);                 // zig-zag
                } else {
                    if (y->_ch[1] == x) rotate(y, 0), rotate(x, 0);  // zig-zig
                    else rotate(x, 1), rotate(x, 0);                 // zig-zag
                }
            }
        }
        x->update();
    }

    // Concatenate subtrees r1 and r2, where every key of r1 precedes every key
    // of r2; returns the new subtree root (0 if both are empty). The rightmost
    // node of r1 is splayed up to just below r1's current parent.
    Node* join(Node* r1, Node* r2) {
        if (r1 && r2) {
            Node* right_end = r1->get_end(1);
            splay(right_end, r1->_fa);
            right_end->_ch[1] = r2;  // right_end is the maximum, so this slot is free
            r2->_fa = right_end;
            right_end->update();
            return right_end;
        }
        if (r1) return r1;
        if (r2) return r2;
        return 0;
    }

    // Link the detached node nd (e.g. freshly set_value'd) into the tree, placing
    // it right of equal keys, and splay it to the root. An empty tree just
    // becomes nd; previously that case dereferenced the null parent y.
    void insert(Node* nd) {
        if (_root == 0) {
            nd->_fa = _fa_root;
            _root = nd;
            return;
        }
        Node *x = _root, *y = _fa_root;
        while (x) {
            y = x;
            if (nd->_val < x->_val) x = x->_ch[0];
            else x = x->_ch[1];
        }
        if (nd->_val < y->_val) y->_ch[0] = nd;
        else y->_ch[1] = nd;
        nd->_fa = y;
        splay(nd, _fa_root);
    }

    // Unlink node nd from the tree. nd's own fields are left stale; callers
    // reset the slot with set_value() before reusing it.
    void erase(Node* nd) {
        Node* x = join(nd->_ch[0], nd->_ch[1]);
        if (nd->_fa == _fa_root) {
            _root = x;
        } else if (nd->_fa->_ch[0] == nd) {
            nd->_fa->_ch[0] = x;
        } else {
            nd->_fa->_ch[1] = x;
        }
        if (x) x->_fa = nd->_fa;
        // The parent and all of its ancestors need their sizes refreshed (x was
        // already updated inside join); splaying the parent to the root does it.
        if (nd->_fa) splay(nd->_fa, _fa_root);
    }

    // Number of keys strictly less than val, found by a top-down walk (no splay).
    int query_less(int val) {
        int res = 0;
        Node* x = _root;
        while (x) {
            if (val <= x->_val) {
                x = x->_ch[0];
            } else {
                ++res;
                if (x->_ch[0]) res += x->_ch[0]->_size;
                x = x->_ch[1];
            }
        }
        return res;
    }
};

Splay::Node* const Splay::_fa_root = 0;
Splay::Node Splay::_nodes[MAXN + 1][18];

int v[MAXN + 1];  // current array values, 1-indexed

// Segment-tree node covering positions [_a, _b] at depth _dep; _bst holds all
// their values, using slots Splay::_nodes[a.._b][_dep].
struct Node {
    int _a, _b, _dep;
    Node *_left, *_right;
    Splay _bst;
} nodes[MAXN * 2], *ptr;

// Build the segment tree over [a, b] in preorder starting at ++ptr, so the
// overall root is nodes + 1.
void build_tree(int a, int b, int dep) {
    Node* root = ++ptr;
    root->_a = a;
    root->_b = b;
    root->_dep = dep;
    root->_bst._root = Splay::_nodes[a] + dep;
    root->_bst._root->set_value(v[a]);
    if (a == b) {
        root->_left = root->_right = 0;
        return;
    }
    for (int i = a + 1; i <= b; ++i) {
        Splay::_nodes[i][dep].set_value(v[i]);
        root->_bst.insert(Splay::_nodes[i] + dep);
    }
    int mid = (a + b) / 2;
    root->_left = ptr + 1;
    build_tree(a, mid, dep + 1);
    root->_right = ptr + 1;
    build_tree(mid + 1, b, dep + 1);
}

int la, lb, lk;  // arguments shared with the recursive query/modify helpers

// Point update v[la] = lk: at each level, unlink position la's slot, overwrite
// its key, and link it back in. A leaf's tree is that single slot.
void modify(Node* root) {
    if (root->_a == root->_b) {
        root->_bst._root->set_value(v[root->_a] = lk);
        return;
    }
    Splay::Node* slot = Splay::_nodes[la] + root->_dep;
    root->_bst.erase(slot);
    slot->set_value(lk);
    root->_bst.insert(slot);
    int mid = (root->_a + root->_b) / 2;
    if (la <= mid) modify(root->_left);
    else modify(root->_right);
}

// Count values < lk among positions [la, lb].
int query(Node* root) {
    if (la <= root->_a && root->_b <= lb) return root->_bst.query_less(lk);
    int mid = (root->_a + root->_b) / 2;
    int res = 0;
    if (la <= mid) res = query(root->_left);
    if (mid < lb) res += query(root->_right);
    return res;
}

// k-th smallest value among v[a..b]: the largest l in [0, 1e9] with fewer than
// k values below l. The search bounds assume values lie in [0, 1e9].
inline int solve(int a, int b, int k) {
    int l = 0, r = 1000000001, mid;
    while (l + 1 != r) {
        mid = (l + r) / 2;
        la = a;
        lb = b;
        lk = mid;
        if (query(nodes + 1) < k) l = mid;
        else r = mid;
    }
    return l;
}

int main() {
#ifndef ONLINE_JUDGE
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
#endif
    int test;
    scanf("%d", &test);
    for (int cas = 1; cas <= test; ++cas) {
        int n, m;
        scanf("%d%d", &n, &m);
        for (int i = 1; i <= n; ++i) scanf("%d", v + i);
        ptr = nodes;
        build_tree(1, n, 0);
        for (int i = 1; i <= m; ++i) {
            char c;
            int a, b, k;
            scanf(" %c", &c);
            if (c == 'Q') {
                scanf("%d%d%d", &a, &b, &k);
                printf("%d\n", solve(a, b, k));
            } else {
                scanf("%d%d", &a, &k);
                la = a;
                lk = k;
                modify(nodes + 1);
            }
        }
    }
    return 0;
}