题目链接:Game on Tree 3
有一棵含有 n n n 个节点的树,节点编号从 1 1 1 到 n n n,根节点为 1 1 1,所有非根节点均有一个正整数权值。根节点上放有一个棋子。T 和 A 两个人正在玩一个回合制游戏。一个回合中:
游戏结束时 T 会获得棋子所在位置的权值的得分。T 想最大化得分,而 A 想最小化得分,问两人在最优策略下 T 最后的得分是多少。
先看官方题解:
由于验证 T 能否至少得到 x x x 分比较容易,我们可以二分他的得分。假设当前 check 的是至少得到 x x x 分的情况,则将树上权值小于 x x x 的节点染成白色,权值大于等于 x x x 的节点染成黑色,然后进行树形 dp:设 d p [ u ] dp[u] dp[u] 表示在以 u u u 为根节点的子树中 A 需要额外染色 d p [ u ] dp[u] dp[u] 次才能使 T 无法走到黑色节点。那么状态转移方程:
d p [ u ] = max ( ∑ d p [ v ] − 1 , 0 ) + [ v a l u ≥ x ] dp[u]=\max(\sum dp[v]-1,0)+[val_u\ge x] dp[u]=max(∑dp[v]−1,0)+[valu≥x]
其中 v v v 为 u u u 的子节点。求和后减一是因为在 T 走下去之前还有一次变颜色的机会。
最后如果 d p [ 1 ] > 0 dp[1]\gt 0 dp[1]>0 说明可以取到大于等于 x x x 的权值。
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int maxn = 2e5 + 5;
vector<int> g[maxn];
int a[maxn], dp[maxn];
void dfs(int u, int f, int x) {
dp[u] = 0;
for (auto v : g[u]) {
if (v == f)
continue;
dfs(v, u, x);
dp[u] += dp[v];
}
dp[u] = max(dp[u] - 1, 0) + (a[u] >= x);
}
void solve() {
int n;
cin >> n;
for (int i = 2; i <= n; ++i) {
cin >> a[i];
}
for (int i = 1, u, v; i < n; ++i) {
cin >> u >> v;
g[u].push_back(v), g[v].push_back(u);
}
int l = 0, r = 1e9, ans = 0;
while (l <= r) {
int mid = (l + r) >> 1;
dfs(1, 0, mid);
if (dp[1] > 0)
l = mid + 1, ans = mid;
else
r = mid - 1;
}
cout << ans << endl;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T = 1;
// cin >> T;
while (T--) {
solve();
}
}
这种染色的技巧在一些数据结构题中也有出现,在不太容易直接计算但比较容易 check 的情况下可以尝试一下。
那么能不能直接求出这个答案呢?我的室友给出了一种更为巧妙的做法:
先考虑树的高度是 1 1 1 的情况,显然 A A A 的最优选择是改变权值最大的那个叶子节点。
如果上面的这个东西是一个子树,那么它就会向父亲的地方输送除去这个权值之外的所有权值。然后又会产生一个改变权值的机会,所以就从这些权值里再删掉一个最大的。这个过程可以用可并堆来维护。
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int maxn = 2e5 + 5;
const ll mod = 998244353;
vector<int> g[maxn];
int a[maxn];
int fa[maxn], ls[maxn], rs[maxn], d[maxn];
int findfa(int x) {
return fa[x] == x ? fa[x] : (fa[x] = findfa(fa[x]));
}
int merge(int x, int y) {
if (!x || !y) {
d[x] = d[y] = 0;
return x + y;
}
if (a[x] < a[y])
swap(x, y);
rs[x] = merge(rs[x], y);
if (d[ls[x]] < d[rs[x]])
swap(ls[x], rs[x]);
d[x] = d[rs[x]] + 1;
return x;
}
int join(int x, int y) {
x = findfa(x), y = findfa(y);
if (x == y)
return x;
fa[x] = fa[y] = merge(x, y);
return fa[x];
}
int pop(int x) {
x = findfa(x);
int t = x, y = rs[x];
x = ls[x];
fa[x] = fa[y] = fa[t] = merge(x, y);
return fa[x];
}
int dfs(int u, int f) {
if (u != 1 && g[u].size() == 1)
return u;
int rt = 0;
for (auto v: g[u]) {
if (v == f)
continue;
int r = dfs(v, u);
if (!rt)
rt = r;
else
rt = join(rt, r);
}
rt = pop(rt);
return join(rt, u);
}
void solve() {
int n;
cin >> n;
for (int i = 2; i <= n; ++i) {
cin >> a[i];
}
for (int i = 1; i <= n; ++i) {
fa[i] = i;
}
for (int i = 1, u, v; i < n; ++i) {
cin >> u >> v;
g[u].push_back(v), g[v].push_back(u);
}
int ans = dfs(1, 0);
cout << a[ans] << endl;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T = 1;
// cin >> T;
while (T--) {
solve();
}
}