子树求和

子树求和是指对一个有根树的每个子树求和。

这里所谓“求和”,不单是指把一堆东西加起来,也包括各式各样信息的汇总。
比如树的每个点上有一些操作,对子树求和可以是把一个子树里的所有操作作用于一个全局数据结构。

子树求和问题通常有三种解法。

  1. 子树合并。
    例如求子树 size 的那种合并法。
  2. 把两个前缀和相减得到一个子树和。
  3. dsu on tree
    一种优化的子树合并。

方法一:子树合并

像求子树 size 那样,可以用直接的子树合并来解决的问题,大家想必都会了。

我们来介绍一个技巧,可以把对路径的询问或操作,转化为子树求和问题。

例题 Max Flow

洛谷P3128

给你一个 个点的树。点从 编号。每个点有个权值,最初都等于

个操作,每个操作给你两个整数

  • 把从 的路径上的每个点(包括点 和点 )的权值加

个操作过后,点的权值的最大值。


把点 的权值记作

我们可以把点 的权值看作一种子树和
也就是说,有一个序列 ,对每个点 都有

把从 的路径上的每个点的权值加 ,对序列 的效果恰是

  • ,给 ,给 ,给

我们把这个技巧称为树上差分

  • 对序列 求子树和结果就是序列 .
  • 所以我们不妨把 看作 在有根树这种结构上的差分序列。

在上述问题中,如果把点权换成边权。把操作改成

  • 把从 的路径上的每条边的权值加

又该如何处理呢?

center

  • 对每个点 ,把连接 和它的父节点的边,即 父边,的权值看作点 的权值。
  • 根节点的父边不存在。
  • 把从 的路径上每条边的权值加 ,也就是对从 的路径上,除了 之外的每个点的权值加
  • 对点权的差分序列 的效果就是:给 ,给 ,给

最近公共祖先

const int maxn = 5e4 + 5;
vector<int> g[maxn];
int num, L[maxn], R[maxn];

bool is_ancestor(int a, int b) { // a 是不是 b 的祖先
    return L[a] <= L[b] && L[b] <= R[a];
}

int anc[maxn][16];

int lca(int u, int v) {
    if (is_ancestor(u, v)) return u;
    if (is_ancestor(v, u)) return v;
    for (int i = 15; i >= 0; i--) {
        if (anc[u][i] && !is_ancestor(anc[u][i], v))
            u = anc[u][i];
    }
    return anc[u][0];
}

void dfs(int u, int p) {
    L[u] = ++num;
    anc[u][0] = p;
    for (int i = 1; i < 16; i++)
        anc[u][i] = anc[anc[u][i - 1]][i - 1];
    for (int v : g[u])
        if (v != p) {
            dfs(v, u);
        }
    R[u] = num;
}
int a[maxn];
void get_sum(int u, int p) { // 求子树和
    for (int v : g[u])
        if (v != p) {
            get_sum(v, u);
            a[u] += a[v];
        }
}

int main() {
    int n, k; cin >> n >> k;
    for (int i = 1; i < n; i++) {
        int u, v; cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1, 0);
    while (k--) {
        int s, t; cin >> s >> t;
        int u = lca(s, t);
        a[s]++; a[t]++;
        a[u]--; a[anc[u][0]]--;
    }
    get_sum(1, 0);
    cout << *max_element(a + 1, a + n + 1) << '\n';
}

例题 运输计划

洛谷P2680

给你一个有 个点的树。点从 编号。
条边()连接点 ,走过这条边需要花费时间

个运输计划。第 个计划()是从点 到点

你可以选一条边把它改造成虫洞,走过虫洞花费的时间是

所有运输计划同时开始。求完成全部计划最少要花多少时间。

暴力

枚举边 ,把边 变成虫洞,计算每条路径的总时间,求最大值。

二分答案

给定总时间的上限

  • 判断是否能通过把一条边改成虫洞,而使得每条路径的总时间都不超过

考虑原本总时间大于 的路径。

  • 改成虫洞的那条边应当是这些路径的公共边。
  • 如果有多条公共边,应当把耗时最多的那条改成虫洞。

问题化为

  • 如何找出原本总时间大于 的那些路径的公共边?

每条边有一个权值,最初都是

  • 设有 条总时间大于 的路径。对每条这样的路径,把上面的每条边的权值加
  • 最后权值等于 的边,就是我们要找的公共边。
long long dist[maxn]; // dist[i]:根到点i的距离
int we[maxn]; //we[i]:点i的父边的长度。
vector<int> postorder; //用非递归的方式计算子树和
void dfs(int u, int p) {
    L[u] = ++num;
    anc[u][0] = p;
    for (int i = 1; i < 19; i++)
        anc[u][i] = anc[anc[u][i - 1]][i - 1];
    for (auto [v, w] : g[u])
        if (v != p) {
            dist[v] = dist[u] + w; 
            dfs(v, u);
            we[v] = w;
        }
    R[u] = num;
    postorder.push_back(u);
}

int a[maxn];
int n, m;
int s[maxn], t[maxn], LCA[maxn];
long long len[maxn];
long long max_len;

bool check(long long k) {
    memset(a, 0, sizeof a);
    int cnt = 0;
    for (int i = 0; i < m; i++)
        if (len[i] > k) {
            cnt++;
            // 把 s[i] 到 t[i] 的路径上的边值加 1
            a[s[i]]++; a[t[i]]++;
            a[LCA[i]] -= 2;
        }
    //计算子树和
    for (int v : postorder)
        a[anc[v][0]] += a[v]; 
    for (int i = 1; i <= n; i++)
        if (a[i] == cnt && max_len - we[i] <= k)
            return true;
    return false;
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin >> n >> m;
    for (int i = 1; i < n; i++) {
        int u, v, w; cin >> u >> v >> w;
        g[u].push_back({v, w});
        g[v].push_back({u, w});
    }
    dfs(1, 0);
    for (int i = 0; i < m; i++) {
        cin >> s[i] >> t[i];
        LCA[i] = lca(s[i], t[i]);
        len[i] = dist[s[i]] + dist[t[i]] - 2 * dist[LCA[i]];
    }
    max_len = *max_element(len, len + m);
    long long ok = max_len, ng = -1;
    while (ok - ng > 1) {
        long long k = (ok + ng) / 2;
        if (check(k))
            ok = k;
        else
            ng = k;
    }
    cout << ok << '\n';
}

方法二:把两个前缀和相减得到一个子树和

前缀:有根树的先序遍历所得序列的前缀。

例题 Tree Requests

CF570D

给你一个有 个点的有根树。点从 编号。
是根。点 的父节点是 )。每个点上有一个小写英文字母。

回答 个询问。第 个询问给你两个整数 ,问

  • 子树 里,深度等于 的点上的那些字符能否构成一个回文串?

在本题里,约定根的深度为

解法

在 DFS 的过程中,我们维护 个长度为 的 01 序列,

  • :在已经访问过的深度等于 的节点上,a 到 z 出现次数的奇偶性。0 偶,1 奇。

为了回答询问 ,在即将进入子树 时,记录一下 的值,在即将离开子树 时,再记录一下 的值。把这先后两个 01 序列异或,结果就是

  • 子树 里深度等于 的那些点上,每种字符出现次数的奇偶性。

若结果中 1 不超过 个,那些字符就能构成回文串。

const int maxn = 5e5 + 5;

vector<int> g[maxn];
char c[maxn];
vector<pair<int,int>> query[maxn];

bitset<26> p[maxn]; //全局数据结构
bitset<26> ans[maxn];

void dfs(int u, int depth) {
    for (auto [h, id] : query[u])
        ans[id] = a[h];
    p[depth].flip(c[u] - 'a');
    for (int v : g[u])
        dfs(v, depth + 1);
    for (auto [h, id] :  query[u])
        ans[id] ^= p[h];
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    int n, m; cin >> n >> m;
    for (int i = 2; i <= n; i++) {
        int p; cin >> p;
        g[p].push_back(i);
    }
    for (int i = 1; i <= n; i++)
        cin >> c[i];
    for (int i = 0; i < m; i++) {
        int v, h; cin >> v >> h;
        query[v].push_back({h, i});
    }
    dfs(1, 1);
    for (int i = 0; i < m; i++)
        cout << (ans[i].count() < 2 ? "Yes" : "No") << '\n';
}

上面的解法是离线的,时间

还有一种类似的在线解法。

对每个 ,定义一个有序的列表

  • :深度为 的点的 DFS 序号,从小到大排列构成的列表。

对每个 ,设点 的 DFS 序号是 ,定义

  • :和点 在同一深度且 DFS 序号小于等于 的那些点上,每种字符出现次数的奇偶性。

对于询问 ,设子树 里的点的 DFS 序号范围是 ,在列表 上二分查找,找出子树 里第一个和最后一个深度是 的点的 DFS 序号。

center

bitset<26> s[maxn];
vector<int> order[maxn];
int L[maxn], R[maxn], num;

void dfs(int u, int depth) {
    L[u] = ++num;
    int prev = order[depth].empty() ? 0 : order[depth].back();
    s[num] = s[prev];
    s[num].flip(c[u] - 'a');
    order[depth].push_back(num);
    for (int v : g[u])
        dfs(v, depth + 1);
    R[u] = num;
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    // 输入 ...
    dfs(1, 1);
    for (int i = 0; i < m; i++) {
        int v, h; cin >> v >> h;
        bitset<26> ans;
        auto l = lower_bound(order[h].begin(), order[h].end(), L[v]);
        auto r = upper_bound(order[h].begin(), order[h].end(), R[v]);
        for (auto ptr : {l, r})
            if (ptr != order[h].begin())
                ans ^= s[*(ptr - 1)];
        cout << (ans.count() > 1 ? "No" : "Yes") << '\n';
    }
}

离线解法更好写。

例题 天天爱跑步

洛谷P1600

有一个有 个点的树。点从 编号。

每天有 个人在树上跑步,第 个人的起点是 ,终点是 。每天,所有人在第 秒同时从自己的起点出发,以每秒跑一条边的速度,不间断地跑向各自的终点。

在每个点上都有一个观察员。一个人能被点 上的观察员观察到,当且仅当此人在第 秒恰好到达点

求每个观察员会观察到多少人?

注意:设某人的终点是 ,如果他在第 秒之前到达点 ,那么他会被观察到。

考虑以点 为根的有根树。
把点 )的深度记作

设某人从 出发跑向 。我们把这条路径上的点分成两段

  • 上行:从
  • 下行:从 ,不包括

每个观察员会观察到多少个上行的人?

某人从 跑向 ,上行经过点 ,且被观察员 观察到,当且仅当

我们可以这样看

  • 每个点 有一个(多重)集合 ,最初为空。
  • 对于每一条路径 的上行段的每个点 ,向 里添加一个数

观察员 观察到的上行人数 = 集合 的个数.

每个观察员会观察到多少个下行的人?

某人从 跑向 ,下行经过点 ,且被观察员 观察到,当且仅当

我们可以这样看

  • 每个点 有一个(多重)集合 ,最初为空。
  • 对于每一条路径 的下行段的每个点 ,向 里添加一个数

观察员 观察到的下行人数 = 集合 的个数.

  • 用树上差分来处理这里的「向一条路径上的每个点的集合里添加同一个数」的操作。

  • 这里的路径是特殊的:两个端点是“祖先—后代”关系。

const int maxn = 3e5 + 5;
int num, L[maxn], R[maxn];

bool is_ancestor(int a, int b) { // a 是不是 b 的祖先
    return L[a] <= L[b] && L[b] <= R[a];
}

int anc[maxn][19];
int lca(int u, int v) {
    if (is_ancestor(u, v)) return u;
    if (is_ancestor(v, u)) return v;
    for (int i = 18; i >= 0; i--) {
        if (anc[u][i] && !is_ancestor(anc[u][i], v))
            u = anc[u][i];
    }
    return anc[u][0];
}

vector<int> g[maxn];
int depth[maxn];
void dfs(int u, int p) {
    L[u] = ++num;
    anc[u][0] = p;
    depth[u] = depth[p] + 1;
    for (int i = 1; i < 19; i++)
        anc[u][i] = anc[anc[u][i - 1]][i - 1];
    for (int v : g[u])
        if (v != p) {
            dfs(v, u);
        }
    R[u] = num;
}
int cnt[2 * maxn];
int ans[maxn];
vector<pair<int,int>> op[maxn];
void get_up(int u, int p) {
    int key = w[u] + depth[u];
    int before = cnt[key];
    for (auto [k, type] : op[u])
        cnt[k] += type;
    for (int v : g[u])
        if (v != p)
            get_up(v, u);
    ans[u] += cnt[key] - before;
}

int n, m;
void get_down(int u, int p) {
    int key = w[u] - depth[u] + n;
    int before = cnt[key];
    for (auto [k, type] : op[u]) {
        cnt[k] += type;
    }
    for (int v : g[u])
        if (v != p)
            get_down(v, u);
    ans[u] += cnt[key] - before;
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin >> n >> m;
    for (int i = 1; i < n; i++) {
        int u, v; cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for (int i = 1; i <= n; i++)
        cin >> w[i];
    
    dfs(1, 0);

    vector<int> s(m), t(m), LCA(m);
    for (int i = 0; i < m; i++) {
        cin >> s[i] >> t[i];
        LCA[i] = lca(s[i], t[i]);
        ans[LCA[i]] += depth[s[i]] == depth[LCA[i]] + w[LCA[i]];
    }
    for (int i = 0; i < m; i++) {
        int k = depth[s[i]];
        op[s[i]].push_back({k, 1});
        op[LCA[i]].push_back({k, -1});
    }
    get_up(1, 0);

    memset(cnt, 0, sizeof cnt);
    for (int i = 1; i <= n; i++)
        op[i].clear();
    
    for (int i = 0; i < m; i++) {
        int time = depth[s[i]] + depth[t[i]] - 2 * depth[LCA[i]];
        int k = time - depth[t[i]] + n;
        op[t[i]].push_back({k, 1});
        op[LCA[i]].push_back({k, -1}); 
    }
    get_down(1, 0);
    for (int i = 1; i <= n; i++)
        cout << ans[i] << ' ';
    cout << '\n';
}

方法三:DSU on Tree