子树求和

子树求和是指对一个有根树的每个子树求和。

这里所谓“求和”,不单是指把一堆东西加起来,也包括各式各样信息的汇总。
比如树的每个点上有一些操作,对子树求和可以是把一个子树里的所有操作作用于一个全局数据结构。

子树求和问题通常有三种解法。

  1. 子树合并。
    例如求子树 size 的那种合并法。
  2. 把两个前缀和相减得到一个子树和。
  3. dsu on tree
    一种优化的子树合并。

方法一:子树合并

center

像求子树 size 那样,可以用直接的子树合并来解决的问题,大家想必都会了。

我们来介绍一个技巧,可以把对路径的询问或操作,转化为子树求和问题。

例题 Max Flow

洛谷P3128

给你一个 个点的树。点从 编号。每个点有个权值,最初都等于

个操作,每个操作给你两个整数

  • 把从 的路径上的每个点(包括点 和点 )的权值加

个操作过后,点的权值的最大值。

考虑以点 为根的有根树
我们可以把每个点的权值看作一种子树和

把点 的权值记作 。也就是说,有一个序列 ,对每个点 都有

我们把从 的路径分成两段来看,可见把此路径每个点的权值加 ,对序列 的效果恰是

  • ,给
  • ,给

我们把这个技巧称为树上差分

  • 对序列 求子树和结果就是序列 .
  • 所以我们不妨把 看作 在有根树这种结构上的差分序列。

在上述问题中,如果把点权换成边权。把操作改成

  • 把从 的路径上的每条边的权值加

又该如何处理呢?

center

  • 对每个点 ,把连接 和它的父节点的边,即 父边,的权值看作点 的权值。
  • 根节点的父边不存在。
  • 把从 的路径上每条边的权值加 ,也就是对从 的路径上,除了 之外的每个点的权值加
  • 对点权的差分序列 的效果就是:给 ,给 ,给

最近公共祖先

const int maxn = 5e4 + 5;
vector<int> g[maxn];
int num, L[maxn], R[maxn];

bool is_ancestor(int a, int b) { // a 是不是 b 的祖先
    return L[a] <= L[b] && L[b] <= R[a];
}

int anc[maxn][16];

int lca(int u, int v) {
    if (is_ancestor(u, v)) return u;
    if (is_ancestor(v, u)) return v;
    for (int i = 15; i >= 0; i--) {
        if (anc[u][i] && !is_ancestor(anc[u][i], v))
            u = anc[u][i];
    }
    return anc[u][0];
}

void dfs(int u, int p) {
    L[u] = ++num;
    anc[u][0] = p;
    for (int i = 1; i < 16; i++)
        anc[u][i] = anc[anc[u][i - 1]][i - 1];
    for (int v : g[u])
        if (v != p) {
            dfs(v, u);
        }
    R[u] = num;
}
int a[maxn];
void get_sum(int u, int p) { // 求子树和
    for (int v : g[u])
        if (v != p) {
            get_sum(v, u);
            a[u] += a[v];
        }
}

int main() {
    int n, k; cin >> n >> k;
    for (int i = 1; i < n; i++) {
        int u, v; cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1, 0);
    while (k--) {
        int s, t; cin >> s >> t;
        int u = lca(s, t);
        a[s]++; a[t]++;
        a[u]--; a[anc[u][0]]--;
    }
    get_sum(1, 0);
    cout << *max_element(a + 1, a + n + 1) << '\n';
}

例题 运输计划

洛谷P2680

给你一个有 个点的树。点从 编号。
条边()连接点 ,走过这条边需要花费时间

个运输计划。第 个计划()是从点 到点

你可以选一条边把它改造成虫洞,走过虫洞花费的时间是

所有运输计划同时开始。求完成全部计划最少要花多少时间。

暴力

枚举边 ,把边 变成虫洞,计算每条路径的总时间,求最大值。

二分答案

给定总时间的上限

  • 判断是否能通过把一条边改成虫洞,而使得每条路径的总时间都不超过

考虑原本总时间大于 的路径。

  • 改成虫洞的那条边应当是这些路径的公共边。
  • 如果有多条公共边,应当把耗时最多的那条改成虫洞。

问题化为

  • 如何找出原本总时间大于 的那些路径的公共边?

每条边有一个权值,最初都是

  • 设有 条总时间大于 的路径。对每条这样的路径,把上面的每条边的权值加
  • 最后权值等于 的边,就是我们要找的公共边。
long long dist[maxn]; // dist[i]:根到点i的距离
int we[maxn]; //we[i]:点i的父边的长度。
vector<int> postorder; //用非递归的方式计算子树和
void dfs(int u, int p) {
    L[u] = ++num;
    anc[u][0] = p;
    for (int i = 1; i < 19; i++)
        anc[u][i] = anc[anc[u][i - 1]][i - 1];
    for (auto [v, w] : g[u])
        if (v != p) {
            dist[v] = dist[u] + w; 
            dfs(v, u);
            we[v] = w;
        }
    R[u] = num;
    postorder.push_back(u);
}

int a[maxn];
int n, m;
int s[maxn], t[maxn], LCA[maxn];
long long len[maxn];
long long max_len;

bool check(long long k) {
    memset(a, 0, sizeof a);
    int cnt = 0;
    for (int i = 0; i < m; i++)
        if (len[i] > k) {
            cnt++;
            // 把 s[i] 到 t[i] 的路径上的边值加 1
            a[s[i]]++; a[t[i]]++;
            a[LCA[i]] -= 2;
        }
    //计算子树和
    for (int v : postorder)
        a[anc[v][0]] += a[v]; 
    for (int i = 1; i <= n; i++)
        if (a[i] == cnt && max_len - we[i] <= k)
            return true;
    return false;
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin >> n >> m;
    for (int i = 1; i < n; i++) {
        int u, v, w; cin >> u >> v >> w;
        g[u].push_back({v, w});
        g[v].push_back({u, w});
    }
    dfs(1, 0);
    for (int i = 0; i < m; i++) {
        cin >> s[i] >> t[i];
        LCA[i] = lca(s[i], t[i]);
        len[i] = dist[s[i]] + dist[t[i]] - 2 * dist[LCA[i]];
    }
    max_len = *max_element(len, len + m);
    long long ok = max_len, ng = -1;
    while (ok - ng > 1) {
        long long k = (ok + ng) / 2;
        if (check(k))
            ok = k;
        else
            ng = k;
    }
    cout << ok << '\n';
}

方法二:把两个前缀和相减得到一个子树和

前缀:有根树的先序遍历所得序列的前缀。
center

例题 子树颜色询问 2

洛谷U639895

给你一个有 个点的有根树。点从 编号。点 是根。
每个点有一个颜色,点 的颜色是

回答 个询问。第 个询问给你两个整数 ,问

  • 子树 里有多少个颜色是 的点?


在 DFS 的过程中,我们维护一个数组

  • :访问过的点中颜色是 的有多少个。

要知道子树 里有多少个点颜色是 ,我们可以这样做

  • 在即将进入子树 时,记下 的值。
  • 在即将离开子树 时,记下 的值。
  • 后一个值减前一个值,结果就是子树 里颜色是 的点的数量。
const int maxn = 2e5 + 5;
vector<int> g[maxn];
int a[maxn];
vector<pair<int,int>> query[maxn];

int cnt[maxn];
int ans[maxn];

void dfs(int u, int p) {
    for (auto [c, i] : query[u])
        ans[i] = cnt[c];
    cnt[a[u]]++;
    for (int v : g[u])
        if (v != p)
            dfs(v, u);
    for (auto [c, i] : query[u])
        ans[i] = cnt[c] - ans[i];
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    int n; cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    for (int i = 0; i < n - 1; i++) {
        int u, v; cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    int m; cin >> m;
    for (int i = 0; i < m; i++) {
        int u, c; cin >> u >> c;
        query[u].push_back({c, i});
    }
    dfs(1, 0);
    for (int i = 0; i < m; i++)
        cout << ans[i] << '\n';
}

上面的解法是离线

  • 先读入所有询问,再处理询问,最后输出每个询问的答案。

现在来考虑在线的解法

  • 先做一些预处理,然后回答询问。
    每读入一个询问,输出答案,再读入下一个询问。

预处理

求 DFS 序号。对每个点 ,定义

  • :点 的 DFS 序号
  • :子树 里最后一个被 DFS 访问的点的 DFS 序号。

把所有点按颜色分类。对每个颜色 ,定义

  • :颜色是 的点的 DFS 序号,从小到大排列,构成的列表。

回答询问

我们知道,点 在子树
所以,子树 里颜色是 的点的数量 = 里值在 之间的 DFS 序号的数量。
后者可以通过在有序的列表 上做两次二分查找来得到。

const int maxn = 2e5 + 5;
vector<int> g[maxn];
int a[maxn];

vector<int> V[maxn];
int L[maxn], R[maxn], num;

void dfs(int u, int p) {
    L[u] = ++num;
    V[a[u]].push_back(L[u]);
    for (int v : g[u])
        if (v != p)
            dfs(v, u);
    R[u] = num;
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    int n; cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    for (int i = 0; i < n - 1; i++) {
        int u, v; cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1, 0);
    int m; cin >> m;
    for (int i = 0; i < m; i++) {
        int u, c; cin >> u >> c;
        cout <<
        upper_bound(V[c].begin(), V[c].end(), R[u])
        - lower_bound(V[c].begin(), V[c].end(), L[u])
         << '\n';
    }
}

例题 Tree Requests

CF570D

给你一个有 个点的有根树。点从 编号。
是根。点 的父节点是 )。每个点上有一个小写英文字母。

回答 个询问。第 个询问给你两个整数 ,问

  • 子树 里,深度等于 的点上的那些字符,可任意排列,能否构成一个回文串?

在本题里,约定根的深度是

解法

在 DFS 的过程中,我们维护 个长度为 的 01 序列,

  • :在已经访问过的深度等于 的节点上,a 到 z 出现次数的奇偶性。0 偶,1 奇。

为了回答询问 ,在即将进入子树 时,记录一下 的值,在即将离开子树 时,再记录一下 的值。把这先后两个 01 序列异或,结果就是

  • 子树 里深度等于 的那些点上,每种字符出现次数的奇偶性。

若结果中 1 不超过 个,那些字符就能被排成回文串。

const int maxn = 5e5 + 5;

vector<int> g[maxn];
char c[maxn];
vector<pair<int,int>> query[maxn];

bitset<26> p[maxn]; //全局数据结构
bitset<26> ans[maxn];

void dfs(int u, int depth) {
    for (auto [h, id] : query[u])
        ans[id] = a[h];
    p[depth].flip(c[u] - 'a');
    for (int v : g[u])
        dfs(v, depth + 1);
    for (auto [h, id] :  query[u])
        ans[id] ^= p[h];
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    int n, m; cin >> n >> m;
    for (int i = 2; i <= n; i++) {
        int p; cin >> p;
        g[p].push_back(i);
    }
    for (int i = 1; i <= n; i++)
        cin >> c[i];
    for (int i = 0; i < m; i++) {
        int v, h; cin >> v >> h;
        query[v].push_back({h, i});
    }
    dfs(1, 1);
    for (int i = 0; i < m; i++)
        cout << 
        (ans[i].count() < 2 ? "Yes" : "No")
        << '\n';
}

上面的解法是离线的。现在来考虑在线解法。

预处理

求 DFS 序号。对每个点 ,定义

  • :点 的 DFS 序号。
  • :子树 里最后一个被 DFS 访问的点的 DFS 序号。

把所有点按深度分组。对每个深度 ,定义

  • :深度为 的点的 DFS 序号,从小到大排列,构成的列表。

对每一组点,求前缀和。这里介绍两种实现方法。

方法一:对每个深度 ,定义序列 表示列表 的长度)

  • :一个长为 的 01 序列,表示 里的前 个点上,每种字符出现次数的奇偶性。

方法二:定义序列

  • 设 DFS 序号等于 的点是 ,点 的深度是
  • :一个长为 的 01 序列,表示在深度等于 且 DFS 序号不超过 的那些点上,每种字符出现次数的奇偶性。

回答询问

对于询问 ,子树 里的点的 DFS 序号范围是 ,据此在 二分查找

如果采用上述方法一来表示前缀和,设 中落在 内的 DFS 序号的下标范围是 ,那么

  • 子树 里深度是 的那些点上,每种字符出现次数的奇偶性就是

如果采用上述方法二来表示前缀和,设 中最后一个小于 的 DFS 序号是 ,最后一个不大于 的 DFS 序号是 ,那么

  • 子树 里深度是 的那些点上,每种字符出现次数的奇偶性就是

center

下列代码是采用上述方法二来表示前缀和。

bitset<26> s[maxn];
vector<int> V[maxn];
int L[maxn], R[maxn], num;

void dfs(int u, int depth) {
    L[u] = ++num;
    s[num] = s[V[depth.back()]];
    s[num].flip(c[u] - 'a');
    V[depth].push_back(num);
    for (int v : g[u])
        dfs(v, depth + 1);
    R[u] = num;
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    // 输入 ...
    for (int i = 1; i <= n; i++)
        V[i].push_back(0);
    dfs(1, 1);
    for (int i = 0; i < m; i++) {
        int v, h; cin >> v >> h;
        bitset<26> ans;
        auto l = lower_bound(V[h].begin(), V[h].end(), L[v]);
        auto r = upper_bound(V[h].begin(), V[h].end(), R[v]);
        bitset<26> ans = s[*(r - 1)] ^ s[*(l - 1)];
        cout << (ans.count() > 1 ? "No" : "Yes") << '\n';
    }
}

上面代码里的 bitset<26> 也可换为 int

int s[maxn];
vector<int> V[maxn];
int L[maxn], R[maxn], num;

void dfs(int u, int depth) {
    L[u] = ++num;
    s[num] = s[V[depth.back()]] ^ 1 << (c[u] - 'a');
    V[depth].push_back(num);
    for (int v : g[u])
        dfs(v, depth + 1);
    R[u] = num;
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    // 输入 ...
    for (int i = 1; i <= n; i++)
        V[i].push_back(0);
    dfs(1, 1);
    for (int i = 0; i < m; i++) {
        int v, h; cin >> v >> h;
        auto l = lower_bound(V[h].begin(), V[h].end(), L[v]);
        auto r = upper_bound(V[h].begin(), V[h].end(), R[v]);
        int ans = s[*(r - 1)] ^ s[*(l - 1)];
        cout << (ans & (ans - 1) ? "No" : "Yes") << '\n';
    }
}

离线解法的优点

  • 通常,离线解法更好写。
  • 有些问题并没有简单的在线解法。

例题 天天爱跑步

洛谷P1600

有一个有 个点的树。点从 编号。

每天有 个人在树上跑步,第 个人的起点是 ,终点是 。每天,所有人在第 秒同时从自己的起点出发,以每秒跑一条边的速度,不间断地跑向各自的终点。

在每个点上都有一个观察员。一个人能被点 上的观察员观察到,当且仅当此人在第 秒恰好到达点

求每个观察员会观察到多少人?

注意:设某人的终点是 ,如果他在第 秒之前到达点 ,那么他会被观察到。

考虑以点 为根的有根树

设某人从 出发跑向 。我们把路径 上的点分成两段

  • 上行:从
  • 下行:从 ,不包括

每个观察员会观察到多少个上行的人?

某人从 跑向 ,上行经过点 ,且被观察员 观察到,当且仅当

我们可以这样看

  • 每个点 有一个(多重)集合 ,最初为空。
  • 对于每一条路径 的上行段的每个点 ,向 里添加一个数

观察员 观察到的上行人数 = 集合 的个数.

每个观察员会观察到多少个下行的人?

某人从 跑向 ,下行经过点 ,且被观察员 观察到,当且仅当

我们可以这样看

  • 每个点 有一个(多重)集合 ,最初为空。
  • 对于每一条路径 的下行段的每个点 ,向 里添加一个数

观察员 观察到的下行人数 = 集合 的个数.


是树上两点, 的祖先。
向路径 上每个点 的集合 里添加一个数 ,在差分序列上的效果就是

  • 在点 处加一个「添加一个 」的操作,
  • 在点 的父节点处加一个「删除一个 」的操作。

然后,对每个点

  • 求集合 就化为对子树 (里的操作)求和,
  • 我们想要知道 里有多少个
const int maxn = 3e5 + 5;
int num, L[maxn], R[maxn];

bool is_ancestor(int a, int b) { // a 是不是 b 的祖先
    return L[a] <= L[b] && L[b] <= R[a];
}

int anc[maxn][19];
int lca(int u, int v) {
    if (is_ancestor(u, v)) return u;
    if (is_ancestor(v, u)) return v;
    for (int i = 18; i >= 0; i--)
        if (anc[u][i] && !is_ancestor(anc[u][i], v))
            u = anc[u][i];
    return anc[u][0];
}

vector<int> g[maxn];
int depth[maxn];
void dfs(int u, int p) {
    L[u] = ++num;
    anc[u][0] = p;
    depth[u] = depth[p] + 1;
    for (int i = 1; i < 19; i++)
        anc[u][i] = anc[anc[u][i - 1]][i - 1];
    for (int v : g[u])
        if (v != p)
            dfs(v, u);
    R[u] = num;
}
int w[maxn];
int cnt[2 * maxn];
int ans[maxn];
vector<pair<int,int>> op_up[maxn], op_down[maxn];
void get_up(int u, int p) {
    int key = w[u] + depth[u];
    int before = cnt[key];
    for (auto [x, delta] : op_up[u])
        cnt[x] += delta;
    for (int v : g[u])
        if (v != p)
            get_up(v, u);
    ans[u] += cnt[key] - before;
}

void get_down(int u, int p) {
    int key = w[u] - depth[u] + maxn;
    int before = cnt[key];
    for (auto [x, delta] : op_down[u])
        cnt[x + maxn] += delta;
    for (int v : g[u])
        if (v != p)
            get_down(v, u);
    ans[u] += cnt[key] - before;
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    int n, m; cin >> n >> m;
    for (int i = 1; i < n; i++) {
        int u, v; cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for (int i = 1; i <= n; i++)
        cin >> w[i];

    dfs(1, 0);

    for (int i = 0; i < m; i++) {
        int s, t; cin >> s >> t;
        int u = lca(s, t);
        op_up[s].push_back({depth[s], 1});
        op_up[anc[u][0]].push_back({depth[s], -1});
        op_down[t].push_back({depth[s] - 2 * depth[u], 1});
        op_down[u].push_back({depth[s] - 2 * depth[u], -1});
    }

    get_up(1, 0);
    memset(cnt, 0, sizeof cnt);
    get_down(1, 0);

    for (int i = 1; i <= n; i++)
        cout << ans[i] << ' ';
    cout << '\n';
}

方法三:DSU on Tree

有些子树求和问题不适合用前两种方法解决。
其中一类问题是这样的

  • 给我们一个有根树。每个点上有一条信息。
  • 我们需要对每个子树求和,也就是把子树的每个点的信息汇总。
  • 我们想要的子树和可以这样来获得
    • 准备一个空的数据结构
    • 把一个子树里的信息逐条的加到 里。
    • 利用 查询子树和。

例题 树上数颜色

洛谷U41492

给你一个有 个点的有根树。点从 编号。点 是根。每个点有个颜色,点 的颜色是

回答 个询问。每个询问给你一个点 ,问子树 里有多少种不同的颜色。

对每个点 ,求子树 里有多少种颜色。

朴素的解法:

  • 使用一个长度为 的数组 。最初每个 都等于
  • 对于子树 里的每个点 ,给 加一。同时维护出现的颜色种类数
  • 记录子树 的答案,也就是最后的
  • 数组清空:对于子树 里的每个点 ,把 置为 。把 置为

时间是

在这个问题里,一个子树里的信息量可以用它的 size(即点的数量)来衡量。

对上面的朴素解法,有这么一种优化。

第一步:预处理

  • 求出每个子树的大小。
  • 对每个点 ,找出 的孩子中 size 最大的那个,如果多个,任取一个。设它是
    我们把 称为 重孩子,把子树 称为 重子树。其余孩子称为轻孩子,其余子树称为轻子树

center

我们把 数组和 设置为全局变量,把二者统称为全局数据结构全局状态

第二步:从根节点开始对树做一次 DFS。

dfs(u):
    对于 u 的每个轻孩子 v:
        dfs(v)
        清空全局状态
    dfs(heavy[u])
    把子树 u 的除了重子树之外的部分加入全局状态

DFS 的另一种写法

dfs(u):
    对于 u 的每个轻孩子 v:
        dfs(v)
    dfs(heavy[u])
    把子树 u 的除了重子树之外的部分加入全局状态
    if (u 不是重孩子)
        清空全局状态
const int maxn = 1e5 + 5;
vector<int> g[maxn];
int sz[maxn], heavy_child[maxn];

void get_size(int u, int p) {
    sz[u] = 1;
    for (int v : g[u])
        if (v != p) {
            get_size(v, u);
            sz[u] += sz[v];
            if (sz[heavy_child[u]] < sz[v])
                heavy_child[u] = v;
        }
}
int cnt[maxn], nc; //全局数据结构
int col[maxn];
int ans[maxn];

int preorder[maxn], num;

void dfs(int u, int p, bool keep) {
    int l = num;
    preorder[num++] = u;
    for (int v : g[u])
        if (v != p && v != heavy_child[u])
            dfs(v, u, false);
    int r = num;
    if (heavy_child[u])
        dfs(heavy_child[u], u, true);
    for (int i = l; i < r; i++)
        nc += ++cnt[col[preorder[i]]] == 1;
    ans[u] = nc;
    if (!keep) {
        for (int i = l; i < num; i++)
            cnt[col[preorder[i]]] = 0;
        nc = 0;
    }
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    int n; cin >> n;
    for (int i = 0; i < n - 1; i++) {
        int u, v; cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for (int i = 1; i <= n; i++)
        cin >> col[i];
    get_size(1, 0);
    dfs(1, 0, true);
    int m; cin >> m;
    while (m--) {
        int u; cin >> u;
        cout << ans[u] << '\n';
    }
}