当前位置:首页 > 技术 > 正文内容

高级数据结构:莫队算法、链剖分与分块技术

访客 技术 2026年5月23日 3

莫队算法概述

莫队算法是一种用于解决静态区间查询问题的离线算法。当已知区间 [l,r] 的答案能够通过移动端点扩展到 [l±1,r±1] 获得时,我们可以在 O(n√n) 的时间复杂度内处理所有查询。设查询数量与数组规模同阶。

算法思想

莫队的核心思想较为直接:将所有查询离线后进行排序,然后暴力地将前一个区间的答案逐步移动到下一个区间。排序策略通常以左端点所在块的编号为第一关键字,右端点为第二关键字,有时还会加入奇偶性优化。块长一般设置为 n/√m,不同的变体会对应不同的最优块长。

基本框架

long long left = 1, right = 0, answer = 0;
for (int i = 1; i <= m; ++i) {
    while (left > query[i].left) add(--left);
    while (left < query[i].left) del(left++);
    while (right > query[i].right) del(right--);
    while (right < query[i].right) add(++right);
    result[query[i].id] = answer;
}

通过大量练习即可掌握莫队算法的精髓。

回滚莫队

回滚莫队用于处理区间转移时增加或删除操作无法同时实现的问题。当仅能实现增加操作或仅能实现删除操作时,可以采用回滚莫队在 O(n√m) 时间内解决问题。其核心思想是:只使用可行的操作,将另一步操作留待回滚处理。回滚莫队分为只增型和只减型两种。

只增莫队实例:区间最大频率值

对于某些问题,增加元素容易处理,但删除元素时难以确定剩余区间中的最大值。此时应使用只增莫队。

具体实现步骤:

  1. 对原序列进行分块,将查询按左端点所属块编号升序排列,右端点升序为第二关键字。
  2. 按顺序处理询问:
    • 若当前询问的左端点所属块与上一询问不同,则将莫队区间左端点初始化为该块右端点加1,右端点初始化为该块右端点。
    • 若左右端点同属一块,直接扫描区间求解。
    • 若跨块:首先扩展右端点至目标位置,然后扩展左端点,记录答案后撤销左端点改动。

块长最优取 n/√m

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const ll MAXN = 200000;
ll n, m, arr[MAXN], answer[MAXN], current;
ll blockSize, blockCnt, startPos[MAXN], endPos[MAXN], blockId[MAXN];
ll compressed[MAXN], value[MAXN], frequency[MAXN];

struct Query {
    ll left, right, id;
};

Query queries[MAXN];

bool cmp(const Query& a, const Query& b) {
    return blockId[a.left] == blockId[b.left] ? a.right < b.right : blockId[a.left] < blockId[b.left];
}

ll directSolve(ll l, ll r) {
    ll ans = 0;
    static ll temp[MAXN];
    for (int i = l; i <= r; ++i) temp[compressed[i]] = 0;
    for (int i = l; i <= r; ++i) {
        temp[compressed[i]]++;
        ans = max(ans, temp[compressed[i]] * arr[i]);
    }
    return ans;
}

void addElement(ll x) {
    ++frequency[compressed[x]];
    current = max(current, frequency[compressed[x]] * arr[x]);
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    blockSize = sqrt(n);
    blockCnt = ceil(n * 1.0 / blockSize);
    
    for (int i = 1; i <= n; ++i) {
        cin >> arr[i];
        compressed[i] = value[i] = arr[i];
    }
    sort(value + 1, value + n + 1);
    ll distinct = unique(value + 1, value + n + 1) - value - 1;
    for (int i = 1; i <= n; ++i) 
        compressed[i] = lower_bound(value + 1, value + distinct + 1, compressed[i]) - value;
    
    for (int i = 1; i <= n; ++i) {
        startPos[i] = blockSize * (i - 1) + 1;
        endPos[i] = min(blockSize * i, n);
        for (int j = startPos[i]; j <= endPos[i]; ++j) blockId[j] = i;
    }
    
    for (int i = 1; i <= m; ++i) {
        cin >> queries[i].left >> queries[i].right;
        queries[i].id = i;
    }
    sort(queries + 1, queries + m + 1, cmp);
    
    ll left = 1, right = 0;
    for (int i = 1, idx = 1; idx <= blockCnt; ++idx) {
        left = endPos[idx] + 1;
        right = endPos[idx];
        current = 0;
        fill(frequency + 1, frequency + n + 1, 0);
        
        while (blockId[queries[i].left] == idx) {
            if (blockId[queries[i].left] == blockId[queries[i].right]) {
                answer[queries[i].id] = directSolve(queries[i].left, queries[i].right);
                ++i;
                continue;
            }
            while (right < queries[i].right) addElement(++right);
            ll saved = current;
            while (left > queries[i].left) addElement(--left);
            answer[queries[i].id] = current;
            while (left <= endPos[idx]) {
                --frequency[compressed[left]];
                ++left;
            }
            current = saved;
            ++i;
        }
    }
    
    for (int i = 1; i <= m; ++i) cout << answer[i] << '\n';
    return 0;
}

只减莫队实例:区间mex

当删除元素容易而增加元素困难时,使用只减莫队。排序时右端点需按降序排列。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const ll MAXN = 300000;
ll n, m, arr[MAXN], answer[MAXN], current;
ll blockSize, blockCnt, startPos[MAXN], endPos[MAXN], blockId[MAXN];
ll freq[MAXN];
bool exists[MAXN];

struct Query {
    ll left, right, id;
};

Query queries[MAXN];

bool cmp(const Query& a, const Query& b) {
    if (blockId[a.left] != blockId[b.left]) 
        return blockId[a.left] < blockId[b.left];
    return a.right > b.right;
}

ll directSolve(ll l, ll r) {
    ll mex = 0;
    for (int i = l; i <= r; ++i) {
        exists[arr[i]] = true;
        while (exists[mex]) ++mex;
    }
    for (int i = l; i <= r; ++i) exists[arr[i]] = false;
    return mex;
}

void removeElement(ll x) {
    --freq[arr[x]];
    if (!freq[arr[x]]) current = min(current, arr[x]);
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    blockSize = n / sqrt(m);
    blockCnt = ceil(n * 1.0 / blockSize);
    
    for (int i = 1; i <= n; ++i) {
        cin >> arr[i];
        ++freq[arr[i]];
    }
    
    for (int i = 1; i <= m; ++i) {
        cin >> queries[i].left >> queries[i].right;
        queries[i].id = i;
    }
    
    for (int i = 1; i <= blockCnt; ++i) {
        startPos[i] = blockSize * (i - 1) + 1;
        endPos[i] = min(blockSize * i, n);
        for (int j = startPos[i]; j <= endPos[i]; ++j) blockId[j] = i;
    }
    
    sort(queries + 1, queries + m + 1, cmp);
    
    ll left, right, initial = 0;
    while (freq[initial]) ++initial;
    
    for (int i = 1, idx = 1; idx <= blockCnt; ++idx) {
        left = startPos[idx];
        right = n;
        current = initial;
        
        while (blockId[queries[i].left] == idx) {
            if (blockId[queries[i].left] == blockId[queries[i].right]) {
                answer[queries[i].id] = directSolve(queries[i].left, queries[i].right);
                ++i;
                continue;
            }
            while (right > queries[i].right) removeElement(right--);
            ll saved = current;
            while (left < queries[i].left) removeElement(left++);
            answer[queries[i].id] = current;
            while (left > startPos[idx]) {
                --left;
                ++freq[arr[left]];
            }
            current = saved;
            ++i;
        }
        
        while (right < n) {
            ++right;
            ++freq[arr[right]];
        }
        while (left < startPos[idx + 1]) {
            --freq[arr[left]];
            if (!freq[arr[left]]) initial = min(initial, arr[left]);
            ++left;
        }
    }
    
    for (int i = 1; i <= m; ++i) cout << answer[i] << '\n';
    return 0;
}

模板应用:区间相同元素最远距离

对于区间内相同元素的最远距离问题,记录每个数值的最左和最右位置,在扩展过程中不断更新最大距离。每个块处理完毕后需清空记录。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const ll MAXN = 300000;
ll n, m, arr[MAXN], comp[MAXN], answer[MAXN], current;
ll blockSize, blockCnt, startPos[MAXN], endPos[MAXN], blockId[MAXN];
ll leftmost[MAXN], rightmost[MAXN], modified[MAXN];

struct Query {
    ll left, right, id;
};

Query queries[MAXN];

bool cmp(const Query& a, const Query& b) {
    return blockId[a.left] == blockId[b.left] ? a.right < b.right : blockId[a.left] < blockId[b.left];
}

ll directSolve(ll l, ll r) {
    static ll pos[MAXN];
    ll ans = 0;
    for (int i = l; i <= r; ++i) pos[comp[i]] = 0;
    for (int i = l; i <= r; ++i) {
        if (!pos[comp[i]]) pos[comp[i]] = i;
        else ans = max(ans, i - pos[comp[i]]);
    }
    return ans;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    blockSize = sqrt(n);
    blockCnt = ceil(n * 1.0 / blockSize);
    
    for (int i = 1; i <= n; ++i) {
        cin >> arr[i];
        comp[i] = arr[i];
    }
    sort(comp + 1, comp + n + 1);
    ll distinct = unique(comp + 1, comp + n + 1) - comp - 1;
    for (int i = 1; i <= n; ++i) 
        arr[i] = lower_bound(comp + 1, comp + distinct + 1, arr[i]) - comp;
    
    for (int i = 1; i <= n; ++i) {
        startPos[i] = blockSize * (i - 1) + 1;
        endPos[i] = min(blockSize * i, n);
        for (int j = startPos[i]; j <= endPos[i]; ++j) blockId[j] = i;
    }
    
    cin >> m;
    for (int i = 1; i <= m; ++i) {
        cin >> queries[i].left >> queries[i].right;
        queries[i].id = i;
    }
    sort(queries + 1, queries + m + 1, cmp);
    
    ll left, right, modifyCnt;
    for (int i = 1, idx = 1; idx <= blockCnt; ++idx) {
        left = endPos[idx] + 1;
        right = endPos[idx];
        current = modifyCnt = 0;
        
        while (blockId[queries[i].left] == idx) {
            if (blockId[queries[i].left] == blockId[queries[i].right]) {
                answer[queries[i].id] = directSolve(queries[i].left, queries[i].right);
                ++i;
                continue;
            }
            while (right < queries[i].right) {
                ++right;
                rightmost[arr[right]] = right;
                modified[++modifyCnt] = arr[right];
                if (!leftmost[arr[right]]) leftmost[arr[right]] = right;
                current = max(current, right - leftmost[arr[right]]);
            }
            ll saved = current;
            while (left > queries[i].left) {
                --left;
                if (!rightmost[arr[left]]) rightmost[arr[left]] = left;
                else current = max(current, rightmost[arr[left]] - left);
            }
            answer[queries[i].id] = current;
            while (left <= endPos[idx]) {
                if (rightmost[arr[left]] == left) rightmost[arr[left]] = 0;
                ++left;
            }
            current = saved;
            ++i;
        }
        for (int j = 1; j <= modifyCnt; ++j) 
            leftmost[modified[j]] = rightmost[modified[j]] = 0;
    }
    
    for (int i = 1; i <= m; ++i) cout << answer[i] << '\n';
    return 0;
}

值域分块与莫队结合

值域分块是对数值范围进行分块的技术,常与莫队算法结合使用以处理更复杂的查询。

区间第k小出现次数

使用莫队配合值域分块,维护每个数的出现次数及每种出现次数的个数。查询时先枚举块并递减k,找到答案所在块后再枚举具体值。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const ll MAXN = 200000;
ll n, m, arr[MAXN], comp[MAXN], answer[MAXN];
ll blockSize, blockCnt, blockId[MAXN], startPos[MAXN], endPos[MAXN];
ll cntValue[MAXN], cntFreq[MAXN], freqBlock[MAXN];

struct Query {
    ll left, right, k, id;
};

Query queries[MAXN];

bool cmp(const Query& a, const Query& b) {
    if (blockId[a.left] != blockId[b.left]) 
        return blockId[a.left] < blockId[b.left];
    return a.right < b.right;
}

void add(ll x) {
    --freqBlock[blockId[cntValue[x]]];
    --cntFreq[cntValue[x]];
    ++cntValue[x];
    ++cntFreq[cntValue[x]];
    ++freqBlock[blockId[cntValue[x]]];
}

void remove(ll x) {
    --freqBlock[blockId[cntValue[x]]];
    --cntFreq[cntValue[x]];
    --cntValue[x];
    ++cntFreq[cntValue[x]];
    ++freqBlock[blockId[cntValue[x]]];
}

ll query(ll k) {
    ll pos = 0;
    for (int i = 1; i <= blockCnt; ++i) {
        if (k - freqBlock[i] <= 0) {
            pos = i;
            break;
        }
        k -= freqBlock[i];
    }
    if (pos == 0) return -1;
    for (int i = startPos[pos]; i <= endPos[pos]; ++i) {
        if (k - cntFreq[i] <= 0) return i;
        k -= cntFreq[i];
    }
    return -1;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    blockSize = n / sqrt(m);
    blockCnt = ceil(n * 1.0 / blockSize);
    
    for (int i = 1; i <= n; ++i) {
        cin >> arr[i];
        comp[i] = arr[i];
    }
    sort(comp + 1, comp + n + 1);
    ll distinct = unique(comp + 1, comp + n + 1) - comp - 1;
    for (int i = 1; i <= n; ++i) 
        arr[i] = lower_bound(comp + 1, comp + distinct + 1, arr[i]) - comp;
    
    for (int i = 1; i <= n; ++i) {
        startPos[i] = blockSize * (i - 1) + 1;
        endPos[i] = min(blockSize * i, n);
        for (int j = startPos[i]; j <= endPos[i]; ++j) blockId[j] = i;
    }
    
    for (int i = 1; i <= m; ++i) {
        cin >> queries[i].left >> queries[i].right >> queries[i].k;
        queries[i].id = i;
    }
    sort(queries + 1, queries + m + 1, cmp);
    
    ll left = queries[1].left, right = left - 1;
    for (int i = 1; i <= m; ++i) {
        while (right < queries[i].right) add(arr[++right]);
        while (left > queries[i].left) add(arr[--left]);
        while (right > queries[i].right) remove(arr[right--]);
        while (left < queries[i].left) remove(arr[left++]);
        answer[queries[i].id] = query(queries[i].k);
    }
    
    for (int i = 1; i <= m; ++i) cout << answer[i] << '\n';
    return 0;
}

区间值域统计

对于查询区间内值域在 [a,b] 范围内的数的个数及不同数的个数,需要维护三个数组:每个数的出现次数、每个值域块中的总数、每个值域块中不同数的个数。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const ll MAXN = 200000;
ll n, m, arr[MAXN];
ll blockSize, blockCnt, blockId[MAXN], startPos[MAXN], endPos[MAXN];
ll blockTotal[MAXN], countNum[MAXN], distinctBlock[MAXN];
ll ans1[MAXN], ans2[MAXN];

struct Query {
    ll left, right, valLeft, valRight, id;
};

Query queries[MAXN];

bool cmp(const Query& a, const Query& b) {
    if (blockId[a.left] != blockId[b.left]) 
        return blockId[a.left] < blockId[b.left];
    return (blockId[a.left] & 1) ? a.right < b.right : a.right > b.right;
}

void add(ll x) {
    ++blockTotal[blockId[arr[x]]];
    ++countNum[arr[x]];
    if (countNum[arr[x]] == 1) ++distinctBlock[blockId[arr[x]]];
}

void remove(ll x) {
    --blockTotal[blockId[arr[x]]];
    --countNum[arr[x]];
    if (countNum[arr[x]] == 0) --distinctBlock[blockId[arr[x]]];
}

void solve(ll a, ll b) {
    ans1 = ans2 = 0;
    if (blockId[a] == blockId[b]) {
        for (int i = a; i <= b; ++i) {
            ans1 += countNum[i];
            ans2 += (countNum[i] > 0);
        }
        return;
    }
    for (int i = a; i <= endPos[blockId[a]]; ++i) {
        ans1 += countNum[i];
        ans2 += (countNum[i] > 0);
    }
    for (int i = startPos[blockId[b]]; i <= b; ++i) {
        ans1 += countNum[i];
        ans2 += (countNum[i] > 0);
    }
    for (int i = blockId[a] + 1; i < blockId[b]; ++i) {
        ans1 += blockTotal[i];
        ans2 += distinctBlock[i];
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    blockSize = max(n / sqrt(m), 1LL);
    blockCnt = ceil(n * 1.0 / blockSize);
    
    for (int i = 1; i <= blockCnt; ++i) {
        startPos[i] = blockSize * (i - 1) + 1;
        endPos[i] = min(blockSize * i, n);
        for (int j = startPos[i]; j <= endPos[i]; ++j) {
            cin >> arr[j];
            blockId[j] = i;
        }
    }
    
    for (int i = 1; i <= m; ++i) {
        cin >> queries[i].left >> queries[i].right >> queries[i].valLeft >> queries[i].valRight;
        queries[i].id = i;
    }
    sort(queries + 1, queries + m + 1, cmp);
    
    ll left = 1, right = 0;
    for (int i = 1; i <= m; ++i) {
        while (left > queries[i].left) add(--left);
        while (left < queries[i].left) remove(left++);
        while (right > queries[i].right) remove(right--);
        while (right < queries[i].right) add(++right);
        solve(queries[i].valLeft, queries[i].valRight);
        ans1[queries[i].id] = ans1;
        ans2[queries[i].id] = ans2;
    }
    
    for (int i = 1; i <= m; ++i) cout << ans1[i] << ' ' << ans2[i] << '\n';
    return 0;
}

树链剖分

树链剖分将树分割成若干条链,常用于在树上进行路径或子树操作。常用的有重链剖分和长链剖分。

重链剖分

基本概念:

  • 重儿子:子节点中子树规模最大的节点
  • 轻儿子:除重儿子外的其他子节点
  • 重边:连接节点与其重儿子的边
  • 轻边:其他边
  • 重链:由重边组成的链

通过重链剖分,树被完全分割为若干条重链,可在dfs序上利用数据结构维护。

重要性质:

  • 所有重链将树完全剖分
  • 经过一条轻边时,子树规模至少减半
  • 任意路径可拆分为从LCA向两端各走最多O(logn)条重链

实现:

第一次DFS记录父节点、深度、子树规模、重子节点。第二次DFS记录链顶、dfs序及对应节点编号。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

struct Edge {
    ll to, next;
};

const ll MAXN = 200000;
ll n, m, root, mod;
ll value[MAXN];
Edge edges[2 * MAXN];
ll head[MAXN], edgeCnt;

void addEdge(ll u, ll v) {
    edges[++edgeCnt].to = v;
    edges[edgeCnt].next = head[u];
    head[u] = edgeCnt;
}

ll parent[MAXN], depth[MAXN], size[MAXN], heavy[MAXN];
ll dfn[MAXN], top[MAXN], order[MAXN], dfnCnt;

void dfs1(ll u, ll p) {
    parent[u] = p;
    size[u] = 1;
    for (int i = head[u]; i; i = edges[i].next) {
        ll v = edges[i].to;
        if (v == p) continue;
        depth[v] = depth[u] + 1;
        dfs1(v, u);
        size[u] += size[v];
        if (size[v] > size[heavy[u]]) heavy[u] = v;
    }
}

void dfs2(ll u, ll t, ll& cnt) {
    top[u] = t;
    dfn[u] = ++cnt;
    order[cnt] = u;
    if (!heavy[u]) return;
    dfs2(heavy[u], t, cnt);
    for (int i = head[u]; i; i = edges[i].next) {
        ll v = edges[i].to;
        if (v == parent[u] || v == heavy[u]) continue;
        dfs2(v, v, cnt);
    }
}

struct SegTree {
    struct Node {
        ll l, r, len;
        ll sum, lazy;
    };
    vector<Node> tree;
    SegTree(ll n) {
        tree.resize(4 * n + 5);
    }
    
    void pushup(ll idx) {
        tree[idx].sum = (tree[idx << 1].sum + tree[idx << 1 | 1].sum) % mod;
    }
    
    void pushdown(ll idx) {
        ll& tag = tree[idx].lazy;
        if (!tag) return;
        ll l = idx << 1, r = idx << 1 | 1;
        tree[l].sum = (tree[l].sum + tag * tree[l].len) % mod;
        tree[r].sum = (tree[r].sum + tag * tree[r].len) % mod;
        tree[l].lazy = (tree[l].lazy + tag) % mod;
        tree[r].lazy = (tree[r].lazy + tag) % mod;
        tag = 0;
    }
    
    void build(ll idx, ll l, ll r) {
        tree[idx].l = l;
        tree[idx].r = r;
        tree[idx].len = r - l + 1;
        if (l == r) {
            tree[idx].sum = value[order[l]] % mod;
            return;
        }
        ll mid = (l + r) >> 1;
        build(idx << 1, l, mid);
        build(idx << 1 | 1, mid + 1, r);
        pushup(idx);
    }
    
    void update(ll idx, ll l, ll r, ll k) {
        if (tree[idx].l >= l && tree[idx].r <= r) {
            tree[idx].sum = (tree[idx].sum + k * tree[idx].len) % mod;
            tree[idx].lazy = (tree[idx].lazy + k) % mod;
            return;
        }
        pushdown(idx);
        ll mid = (tree[idx].l + tree[idx].r) >> 1;
        if (l <= mid) update(idx << 1, l, r, k);
        if (r > mid) update(idx << 1 | 1, l, r, k);
        pushup(idx);
    }
    
    ll query(ll idx, ll l, ll r) {
        if (tree[idx].l >= l && tree[idx].r <= r) return tree[idx].sum;
        pushdown(idx);
        ll mid = (tree[idx].l + tree[idx].r) >> 1;
        ll ans = 0;
        if (l <= mid) ans = (ans + query(idx << 1, l, r)) % mod;
        if (r > mid) ans = (ans + query(idx << 1 | 1, l, r)) % mod;
        return ans;
    }
};

void pathUpdate(ll x, ll y, ll k, SegTree& seg) {
    while (top[x] != top[y]) {
        if (depth[top[x]] < depth[top[y]]) swap(x, y);
        seg.update(1, dfn[top[x]], dfn[x], k);
        x = parent[top[x]];
    }
    if (depth[x] > depth[y]) swap(x, y);
    seg.update(1, dfn[x], dfn[y], k);
}

ll pathQuery(ll x, ll y, SegTree& seg) {
    ll ans = 0;
    while (top[x] != top[y]) {
        if (depth[top[x]] < depth[top[y]]) swap(x, y);
        ans = (ans + seg.query(1, dfn[top[x]], dfn[x])) % mod;
        x = parent[top[x]];
    }
    if (depth[x] > depth[y]) swap(x, y);
    return (ans + seg.query(1, dfn[x], dfn[y])) % mod;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m >> root >> mod;
    for (int i = 1; i <= n; ++i) cin >> value[i];
    for (int i = 1; i < n; ++i) {
        ll u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }
    dfs1(root, 0);
    ll cnt = 0;
    dfs2(root, root, cnt);
    SegTree seg(n);
    seg.build(1, 1, n);
    
    for (int i = 1; i <= m; ++i) {
        ll opt, x, y, z;
        cin >> opt;
        if (opt == 1) {
            cin >> x >> y >> z;
            z %= mod;
            pathUpdate(x, y, z, seg);
        }
        if (opt == 2) {
            cin >> x >> y;
            cout << pathQuery(x, y, seg) << '\n';
        }
        if (opt == 3) {
            cin >> x >> z;
            z %= mod;
            seg.update(1, dfn[x], dfn[x] + size[x] - 1, z);
        }
        if (opt == 4) {
            cin >> x;
            cout << seg.query(1, dfn[x], dfn[x] + size[x] - 1) % mod << '\n';
        }
    }
    return 0;
}

分块技巧

分块是一种灵活的数据结构技术,可用于平衡不同操作的复杂度。

复杂度平衡策略

策略一:O(√n) 区间修改,O(1) 区间求和

维护块内前缀和与块间前缀和。查询时直接合并整块与散块,修改时暴力重新计算。

策略二:O(1) 区间加,O(√n) 区间求和

使用差分思想,维护 Σaᵢ 和 Σaᵢ×i 两个前缀和数组。

值域分块

当值域范围为 O(n) 时,可采用类似权值线段树的形式维护值域分块。

  • O(1) 插入,O(√n) 查询第k小:维护整块内元素个数
  • O(√n) 插入,O(1) 查询第k小:采用更复杂的数据结构

分块结合值域并查集

当值域规模不大且需处理区间内特定类别元素的修改查询时,可将分块与值域并查集结合。

典型应用场景:处理带有颜色标记的区间操作。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const ll MAXN = 300000;
const ll MAXM = 500000;

int n, m, colorCnt;
int blockSize, blockCnt, startPos[MAXN], endPos[MAXN], blockId[MAXN];
ll arr[MAXN], blockSum[MAXN];
int color[MAXN], elementRoot[MAXN], rootNode[MAXN][MAXN];

struct DSU {
    int parent[MAXM];
    int size[MAXM];
    int col[MAXM];
    ll tag[MAXM];
    
    int find(int x) {
        if (!parent[x]) return x;
        auto res = find(parent[x]);
        tag[x] += tag[parent[x]];
        parent[x] = res.id;
        return parent[x];
    }
    
    ll getTag(int x) {
        find(x);
        return tag[x];
    }
};

DSU dsu;

struct Result {
    int id;
    ll value;
};

Result getInfo(int x) {
    int root = dsu.find(elementRoot[x]);
    return {root, dsu.getTag(elementRoot[x])};
}

void processPartialBlock(int l, int r, int fromColor, int toColor, int bid) {
    for (int i = l; i <= r; ++i) {
        auto info = getInfo(i);
        int p = info.id;
        ll accumulated = info.value;
        
        if (dsu.col[p] != fromColor) continue;
        if (--dsu.size[p] == 0) rootNode[bid][dsu.col[p]] = 0;
        
        arr[i] += accumulated + dsu.tag[p];
        
        if (!rootNode[bid][toColor]) {
            rootNode[bid][toColor] = elementRoot[i] = ++dsu.parent[0];
            dsu.parent[dsu.parent[0]] = dsu.tag[dsu.parent[0]] = 0;
            dsu.size[dsu.parent[0]] = 1;
            dsu.col[dsu.parent[0]] = toColor;
        } else {
            elementRoot[i] = rootNode[bid][toColor];
            arr[i] -= dsu.tag[elementRoot[i]];
            ++dsu.size[elementRoot[i]];
        }
    }
}

void updateColor(int l, int r, int fromColor, int toColor) {
    if (fromColor == toColor) return;
    if (blockId[l] == blockId[r]) {
        processPartialBlock(l, r, fromColor, toColor, blockId[l]);
        return;
    }
    processPartialBlock(l, endPos[blockId[l]], fromColor, toColor, blockId[l]);
    processPartialBlock(startPos[blockId[r]], r, fromColor, toColor, blockId[r]);
    
    for (int i = blockId[l] + 1; i < blockId[r]; ++i) {
        int p1 = rootNode[i][fromColor];
        int p2 = rootNode[i][toColor];
        if (!p1) continue;
        if (p2) {
            dsu.size[p2] += dsu.size[p1];
            dsu.parent[p1] = p2;
            dsu.tag[p1] -= dsu.tag[p2];
            rootNode[i][fromColor] = 0;
        } else {
            dsu.col[p1] = toColor;
            swap(rootNode[i][fromColor], rootNode[i][toColor]);
        }
    }
}

void updateValue(int l, int r, int targetColor, ll delta) {
    if (blockId[l] == blockId[r]) {
        for (int i = l; i <= r; ++i) {
            auto info = getInfo(i);
            if (dsu.col[info.id] == targetColor) {
                arr[i] += delta;
                blockSum[blockId[l]] += delta;
            }
        }
        return;
    }
    for (int i = l; i <= endPos[blockId[l]]; ++i) {
        auto info = getInfo(i);
        if (dsu.col[info.id] == targetColor) {
            arr[i] += delta;
            blockSum[blockId[l]] += delta;
        }
    }
    for (int i = startPos[blockId[r]]; i <= r; ++i) {
        auto info = getInfo(i);
        if (dsu.col[info.id] == targetColor) {
            arr[i] += delta;
            blockSum[blockId[r]] += delta;
        }
    }
    for (int i = blockId[l] + 1; i < blockId[r]; ++i) {
        int p = rootNode[i][targetColor];
        if (!p) continue;
        dsu.tag[p] += delta;
        blockSum[i] += dsu.size[p] * delta;
    }
}

ll query(int l, int r) {
    ll ans = 0;
    if (blockId[l] == blockId[r]) {
        for (int i = l; i <= r; ++i) {
            auto info = getInfo(i);
            ans += arr[i] + info.value + dsu.tag[info.id];
        }
        return ans;
    }
    for (int i = l; i <= endPos[blockId[l]]; ++i) {
        auto info = getInfo(i);
        ans += arr[i] + info.value + dsu.tag[info.id];
    }
    for (int i = startPos[blockId[r]]; i <= r; ++i) {
        auto info = getInfo(i);
        ans += arr[i] + info.value + dsu.tag[info.id];
    }
    for (int i = blockId[l] + 1; i < blockId[r]; ++i) ans += blockSum[i];
    return ans;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m >> colorCnt;
    blockSize = sqrt(n);
    blockCnt = ceil(n * 1.0 / blockSize);
    
    for (int i = 1; i <= n; ++i) cin >> arr[i];
    for (int i = 1; i <= n; ++i) cin >> color[i];
    
    for (int i = 1; i <= blockCnt; ++i) {
        startPos[i] = blockSize * (i - 1) + 1;
        endPos[i] = min(blockSize * i, n);
        for (int j = startPos[i]; j <= endPos[i]; ++j) {
            blockId[j] = i;
            blockSum[i] += arr[j];
            if (!rootNode[i][color[j]]) {
                rootNode[i][color[j]] = elementRoot[j] = ++dsu.parent[0];
                dsu.parent[dsu.parent[0]] = dsu.tag[dsu.parent[0]] = 0;
                dsu.size[dsu.parent[0]] = 1;
                dsu.col[dsu.parent[0]] = color[j];
            } else {
                elementRoot[j] = rootNode[i][color[j]];
                ++dsu.size[elementRoot[j]];
            }
        }
    }
    
    for (int i = 1; i <= m; ++i) {
        int opt, l, r, x, y;
        ll k;
        cin >> opt >> l >> r;
        if (opt == 1) {
            cin >> x >> y;
            updateColor(l, r, x, y);
        }
        if (opt == 2) {
            cin >> x >> k;
            updateValue(l, r, x, k);
        }
        if (opt == 3) {
            cout << query(l, r) << '\n';
        }
    }
    return 0;
}

相关文章

Linux crontab 详解

1) crontab 是什么cron 是 Linux 的定时任务守护进程;crontab 是用来编辑/查看“按时间周期执行命令”的表(cron table)。常见两类:用户 crontab:每个用户一份(crontab -e 编辑)系统级 crontab / cron.d:可指定执行用户(/etc/crontab、/etc/cron.d/*)2) crontab 时间...

富文本里可以允许的 HTML 属性

一、所有标签默认允许的安全属性(极少)class        (可选)id           (通常建议禁用)title️ 注意:id 容易被滥用做锚点注入,很多系统直接禁用class 允许的话最好只允许固定前缀(如 editor-*)二、a 标签允许属性<a href="" t...

Mac 安装 Node.js 指南

方法一:通过官网安装包(最简单,适合初学者)如果你只是想快速安装并开始使用,这是最直接的方法。访问 Node.js 官网。页面会显示两个版本:LTS (Recommended For Most Users):长期支持版,最稳定。建议选这个。Current:最新特性版,包含最新功能但可能不够稳定。下载 .pkg 安装包并运行。按照安装向导点击“下一步”即可完成。方法二:使用 Homebrew 安装(...

Dom\HTML_NO_DEFAULT_NS 的副作用:自动加闭合标签

在使用Dom\HTMLDocument时,Dom\HTML_NO_DEFAULT_NS 将禁止在解析过程中设置元素的命名空间, 此设置是为了与DOMDocument向后兼容而存在的。当使用它时,已知的一个副作用就是:自动加闭合标签例如 </img> 为什么会这样?当你使用:Dom\HTML_NO_DEFAULT_NS文档会变成 无命名空间模式,此时内部更接近 XML...

Laravel 事件和监听器创建

在 Laravel 中,使用 Artisan 命令创建 Events(事件) 和 Listeners(监听器) 是非常高效的。你可以通过以下几种方式来实现:1. 手动创建单个 Event如果你只想创建一个事件类,可以使用 make:event 命令:Bashphp artisan make:event UserRegistered执行后,文件将生成在 app/Even...

自定义域名解析神器 dnsmasq

什么是 dnsmasq?dnsmasq 是一个轻量级、功能强大的网络服务工具,专为小型和中等规模网络设计。它是一个综合的网络基础设施解决方案[1]。dnsmasq 能做什么?功能说明应用场景DNS 转发与缓存将 DNS 查询转发到上游服务器(ISP、Google DNS 等),并在本地缓存结果加快 DNS 查询速度,减少外部 DNS 流量本地 DNS解析本地网络设备的主机名,无需编辑&n...

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。