树套树的应用与实现
树套树是一种将两种数据结构嵌套使用的技巧，常用于处理带修改的复杂区间查询（如区间排名、区间第 k 小、区间前驱/后继）。外层可以是线段树或树状数组，负责划分区间；内层可以是平衡树或线段树，负责维护外层每个节点所覆盖区间内的全部值。
常用组合：
- 外层：线段树、树状数组
- 内层：平衡树（可直接使用标准库的 multiset，也可手写 Splay 等）、线段树
示例题解
示例题目一
这是一个简单的线段树套 multiset 的例子：支持单点修改（操作 1），以及查询区间 [a, b] 内严格小于 x 的最大值（其他操作，即区间前驱；不存在时输出 INT_MIN）：
#include <algorithm>
#include <climits>
#include <iostream>
#include <set>
using namespace std;
// Capacity limits: MAXN bounds the array length n, and the segment tree
// built over [1, n] uses at most 4 * MAXN nodes.
constexpr int MAXN = 50005;
constexpr int MAXM = MAXN * 4;

// Array length and number of operations, both read in main().
int n;
int m;

// One segment-tree node: the index range [l, r] it covers, plus a multiset
// holding every value currently stored at positions l..r.
struct SegmentTree {
    int l;
    int r;
    multiset<int> s;
};

SegmentTree tree[MAXM];

// Current array values, 1-indexed.
int w[MAXN];
// Recursively builds segment-tree node u covering positions [l, r].
// The node's multiset gets the INT_MIN / INT_MAX sentinels plus every w[i]
// with l <= i <= r; the two halves then become children 2u and 2u + 1.
void build(int u, int l, int r) {
    auto &node = tree[u];
    node = {l, r};
    node.s.insert(INT_MIN);
    node.s.insert(INT_MAX);
    for (int pos = l; pos <= r; ++pos)
        node.s.insert(w[pos]);
    if (l == r)
        return;
    const int mid = (l + r) / 2;
    build(2 * u, l, mid);
    build(2 * u + 1, mid + 1, r);
}
// Point update: replaces one copy of the old value w[p] with x in every node
// on the root-to-leaf path for position p. The caller must still hold the
// old value in w[p]; main() assigns w[p] = x only after this returns.
void update(int u, int p, int x) {
    for (;;) {
        auto &bag = tree[u].s;
        bag.erase(bag.find(w[p]));  // erase a single occurrence, not all equal keys
        bag.insert(x);
        if (tree[u].l == tree[u].r)
            return;
        const int mid = (tree[u].l + tree[u].r) / 2;
        u = (p <= mid) ? 2 * u : 2 * u + 1;
    }
}
// Returns the largest value strictly less than x among w[a..b] (the range
// predecessor of x). INT_MIN is returned when no such value exists; the
// INT_MIN sentinel stored in every node makes that the natural "not found".
int query(int u, int a, int b, int x) {
    if (tree[u].l >= a && tree[u].r <= b) {
        const auto &bag = tree[u].s;
        auto it = bag.lower_bound(x);
        // For x <= INT_MIN, lower_bound lands on begin() (the INT_MIN sentinel),
        // and decrementing begin() would be undefined behaviour.
        if (it == bag.begin())
            return INT_MIN;
        return *--it;
    }
    const int mid = (tree[u].l + tree[u].r) / 2;
    int res = INT_MIN;
    if (a <= mid)
        res = max(res, query(u << 1, a, b, x));
    if (b > mid)
        res = max(res, query(u << 1 | 1, a, b, x));
    return res;
}
// Reads n values and m operations:
//   1 a x    -> set w[a] = x
//   else a b x -> print the largest value < x in w[a..b] (INT_MIN if none)
int main() {
    // Bulk I/O only through cin/cout, so decouple from C stdio and untie cin.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= n; ++i)
        cin >> w[i];
    build(1, 1, n);
    while (m--) {
        int op, a, b, x;
        cin >> op;
        if (op == 1) {
            cin >> a >> x;
            // update() reads the old w[a], so the array is overwritten afterwards.
            update(1, a, x);
            w[a] = x;
        } else {
            cin >> a >> b >> x;
            // '\n' instead of endl: avoid flushing on every answer.
            cout << query(1, a, b, x) << '\n';
        }
    }
    return 0;
}
示例题目二
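这是另一个示例，使用线段树套 Splay（平衡树）支持更复杂的操作：操作 1 查询 x 在区间 [a, b] 内的排名；操作 2 查询区间 [a, b] 内第 k 小的值（在值域 [0, 10^8] 上二分答案）；操作 3 单点修改；操作 4/5 查询区间内 x 的前驱/后继：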
#include <iostream>
#include <algorithm>
using namespace std;
// MAXN sizes both the splay-node pool and the segment-tree arrays.
// INF is inserted as +INF and -INF into every inner tree as a sentinel key.
constexpr int MAXN = 2000005;
constexpr int INF = 2147483647;

// Array length and number of operations, both read in main().
int n;
int m;

// A splay-tree node in the global pool `nodes`; index 0 is used as the null link.
struct Node {
    int child[2];  // left / right child indices (0 = none)
    int parent;    // parent index (0 = none)
    int value;     // key stored at this node
    int size;      // number of nodes in this subtree
    // Sets key, parent and size; child links are left untouched.
    void init(int v, int p) {
        value = v;
        parent = p;
        size = 1;
    }
};

Node nodes[MAXN];

// leftChild[u] / rightChild[u] store the bounds [l, r] of segment-tree node u
// (despite the names, they are range endpoints rather than child links).
// root[u] is the root of node u's splay tree; indexCounter is the last pool
// index handed out.
int leftChild[MAXN], rightChild[MAXN], root[MAXN], indexCounter;

// Current array values, 1-indexed.
int values[MAXN];
// Recomputes node x's subtree size from its children; a null child (index 0)
// contributes nodes[0].size.
void pushUp(int x) {
    auto &cur = nodes[x];
    const int leftSize = nodes[cur.child[0]].size;
    const int rightSize = nodes[cur.child[1]].size;
    cur.size = 1 + leftSize + rightSize;
}
// Single splay rotation: lifts node x above its parent y while preserving the
// in-order sequence, and hangs x under y's former parent z (z may be 0, the
// null link, in which case splay() is responsible for updating the root).
void rotate(int x) {
int y = nodes[x].parent, z = nodes[y].parent;
// k = 1 if x is y's right child, 0 if it is y's left child.
int k = nodes[y].child[1] == x;
// x takes y's slot under z (when z == 0 this only writes into the null node).
nodes[z].child[nodes[z].child[1] == y] = x;
nodes[x].parent = z;
// x's subtree on the side facing y moves across to become y's child.
nodes[y].child[k] = nodes[x].child[k ^ 1];
nodes[nodes[x].child[k ^ 1]].parent = y;
// y becomes x's child on the opposite side.
nodes[x].child[k ^ 1] = y;
nodes[y].parent = x;
// y now sits below x, so its size must be refreshed before x's.
pushUp(y);
pushUp(x);
}
// Splays node x upward until its parent is `target`. With target == 0, x
// becomes the tree root and the caller's `root` is updated to x.
void splay(int &root, int x, int target) {
while (nodes[x].parent != target) {
int y = nodes[x].parent, z = nodes[y].parent;
// At least two levels remain: if x and y are children on the same side,
// rotate y first (zig-zig); otherwise rotate x twice (zig-zag).
if (z != target)
rotate((nodes[y].child[1] == x) ^ (nodes[z].child[1] == y) ? x : y);
rotate(x);
}
if (!target)
root = x;
}
void insert(int &root, int val) {
int node = root, parent = 0;
while (node)
parent = node, node = nodes[node].child[val > nodes[node].value];
node = ++indexCounter;
if (parent)
nodes[parent].child[val > nodes[parent].value] = node;
nodes[node].init(val, parent);
splay(root, node, 0);
}
int getKth(int root, int val) {
int node = root, rank = 0;
while (node) {
if (nodes[node].value < val)
rank += nodes[nodes[node].child[0]].size + 1, node = nodes[node].child[1];
else
node = nodes[node].child[0];
}
return rank;
}
// Replaces one occurrence of oldVal with newVal in the splay tree `root`.
// Precondition: oldVal is present (callers pass the current array value).
// The -INF / +INF sentinels guarantee the removed node has both an in-order
// predecessor and successor.
void modify(int &root, int oldVal, int newVal) {
    // Locate a node holding oldVal by ordinary BST descent.
    int node = root;
    while (node && nodes[node].value != oldVal)
        node = nodes[node].child[nodes[node].value < oldVal];
    splay(root, node, 0);
    // In-order predecessor: rightmost node of the left subtree.
    int left = nodes[node].child[0];
    while (nodes[left].child[1])
        left = nodes[left].child[1];
    // In-order successor: leftmost node of the right subtree.
    int right = nodes[node].child[1];
    while (nodes[right].child[0])
        right = nodes[right].child[0];
    // Bring the predecessor to the root and the successor directly beneath it;
    // the only node between them in order, `node`, is then right's left child.
    splay(root, left, 0);
    splay(root, right, left);
    nodes[right].child[0] = 0;
    // Recompute sizes bottom-up: right is left's child, so it must go first,
    // otherwise left's size would still count the detached node.
    pushUp(right);
    pushUp(left);
    insert(root, newVal);
}
// Builds segment-tree node u over positions [l, r]: records the bounds,
// seeds its splay tree with the +INF and -INF sentinels, inserts
// values[l..r], then recurses into children 2u and 2u + 1.
void build(int u, int l, int r) {
    leftChild[u] = l;
    rightChild[u] = r;
    int &treeRoot = root[u];
    insert(treeRoot, INF);
    insert(treeRoot, -INF);
    for (int pos = l; pos <= r; ++pos)
        insert(treeRoot, values[pos]);
    if (l == r)
        return;
    const int mid = (l + r) / 2;
    build(2 * u, l, mid);
    build(2 * u + 1, mid + 1, r);
}
// Counts the values in values[a..b] that are strictly less than x. Over the
// segment nodes fully inside [a, b], it sums getKth() and subtracts 1 per node
// to discount the -INF sentinel (this assumes x > -INF).
int query(int u, int a, int b, int x) {
    const int lo = leftChild[u];
    const int hi = rightChild[u];
    if (a <= lo && hi <= b)
        return getKth(root[u], x) - 1;
    const int mid = (lo + hi) / 2;
    const int fromLeft = (a <= mid) ? query(2 * u, a, b, x) : 0;
    const int fromRight = (mid < b) ? query(2 * u + 1, a, b, x) : 0;
    return fromLeft + fromRight;
}
// Point update: in every segment node on the root-to-leaf path of position p,
// replaces the old value values[p] with x. The caller must still hold the old
// value in values[p]; main() overwrites it only after this returns.
void change(int u, int p, int x) {
    for (;;) {
        modify(root[u], values[p], x);
        const int lo = leftChild[u];
        const int hi = rightChild[u];
        if (lo == hi)
            return;
        const int mid = (lo + hi) / 2;
        u = (p <= mid) ? 2 * u : 2 * u + 1;
    }
}
// Answers reported when a predecessor / successor does not exist; they equal
// the -INF / +INF sentinel keys stored in every inner tree.
static const int kNoPredecessor = -2147483647;
static const int kNoSuccessor = 2147483647;

// Returns the k-th smallest (1-based) value of values[a..b] by binary
// searching the answer with query(). Assumes all values lie in [0, 1e8].
static int kthSmallest(int a, int b, int k) {
    int low = 0, high = 100000000;
    while (low < high) {
        const int mid = (low + high + 1) / 2;
        if (query(1, a, b, mid) + 1 <= k)
            low = mid;
        else
            high = mid - 1;
    }
    return high;
}

// Largest value in values[a..b] strictly less than x, or kNoPredecessor.
static int getPredecessor(int a, int b, int x) {
    const int below = query(1, a, b, x);
    // below can be negative for x <= -INF because query() discounts the sentinel.
    return below <= 0 ? kNoPredecessor : kthSmallest(a, b, below);
}

// Smallest value in values[a..b] strictly greater than x, or kNoSuccessor.
static int getSuccessor(int a, int b, int x) {
    if (x >= kNoSuccessor)
        return kNoSuccessor;  // nothing exceeds INT_MAX; also avoids x + 1 overflow
    int notGreater = query(1, a, b, x + 1);
    if (notGreater < 0)
        notGreater = 0;
    return notGreater == b - a + 1 ? kNoSuccessor : kthSmallest(a, b, notGreater + 1);
}

// Reads n values and m operations:
//   1 a b x -> rank of x in values[a..b]
//   2 a b k -> k-th smallest value in values[a..b]
//   3 a x   -> set values[a] = x
//   4 a b x -> predecessor of x in values[a..b]
//   5 a b x -> successor of x in values[a..b]
int main() {
    // Bulk I/O only through cin/cout, so decouple from C stdio and untie cin.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= n; ++i)
        cin >> values[i];
    build(1, 1, n);
    while (m--) {
        int op, a, b, x;
        cin >> op;
        if (op == 1) {
            cin >> a >> b >> x;
            // Rank = 1 + number of values strictly less than x.
            cout << query(1, a, b, x) + 1 << '\n';
        } else if (op == 2) {
            cin >> a >> b >> x;
            cout << kthSmallest(a, b, x) << '\n';
        } else if (op == 3) {
            cin >> a >> x;
            // change() reads the old values[a], so the array is overwritten afterwards.
            change(1, a, x);
            values[a] = x;
        } else if (op == 4) {
            cin >> a >> b >> x;
            cout << getPredecessor(a, b, x) << '\n';
        } else {
            cin >> a >> b >> x;
            cout << getSuccessor(a, b, x) << '\n';
        }
    }
    return 0;
}