当前位置: 首页 > 面试经验 >

2024/5/6携程笔试软件开发岗位第四题题解 换根树形DP

优质
小牛编辑
75浏览
2024-05-06

2024/5/6携程笔试软件开发岗位第四题题解 换根树形DP

​题目描述:

游游拿到了一棵树,其中每个节点上有一个数字('0'~'9')。

现在游游定义f(i)为:以i号节点为起点时,取一条路径,上面所有数字拼起来是3的倍数的方案数。

现在小红希望你求出f(1)到f(n)的值,你能帮帮她吗?注:前导零也是合法的。

更好的观看体验请移步:https://blog.csdn.net/qq_67243927/article/details/138507852?spm=1001.2014.3001.5502

题解:

暴力:从每个根开始暴力,发现会超时

正解:树形DP+换根

注意到题目要求的是拼合为3的倍数,根据3的倍数的性质,我们只需要对每位数求和看是否为3的倍数即可,动态规划状态设定如下:

tmp[i][k]表示以i为根的路径拼和模3后为k的方案数 k取0到2 

例如tmp[1][0]代表以i根的路径拼和模3后为0(3的倍数)的方案数 

下面是状态转移公式: x是根,v是儿子

		if (a[x] % 3 == 2) {
			tmp[x][0] += tmp[v][1];
			tmp[x][1] += tmp[v][2];
			tmp[x][2] += tmp[v][0];
		}
		if (a[x] % 3 == 1) {
			tmp[x][0] += tmp[v][2];
			tmp[x][1] += tmp[v][0];
			tmp[x][2] += tmp[v][1];
		}
		if (a[x] % 3 == 0) {
			tmp[x][0] += tmp[v][0];
			tmp[x][1] += tmp[v][1];
			tmp[x][2] += tmp[v][2];
		}

考虑换根

当要把根从x换到v的时候,注意到我们只需要将v对x的贡献从x中剔除,tmp[x][k]满足换根为v后的情况,此时x成为v的儿子,我们只需要再将x的贡献加到v中即可算出换根后的tmp[v][k],然后遍历换根即可。

代码如下:

//
#include <bits/stdc++.h>
using namespace std;
#define maxn 200111
#define ll long long
int n;
ll a[maxn];
ll tmp[maxn][4];
vector<ll> g[maxn];
ll f[maxn];
int mark[maxn];

void dfs(int x, int fa) {
	tmp[x][a[x] % 3] = 1; //自己单独为一个路径的情况

	for (int v : g[x]) {
		if (v == fa)
			continue;

		dfs(v, x); //遍历树
		if (a[x] % 3 == 2) {
			tmp[x][0] += tmp[v][1];
			tmp[x][1] += tmp[v][2];
			tmp[x][2] += tmp[v][0];
		}
		if (a[x] % 3 == 1) {
			tmp[x][0] += tmp[v][2];
			tmp[x][1] += tmp[v][0];
			tmp[x][2] += tmp[v][1];
		}
		if (a[x] % 3 == 0) {
			tmp[x][0] += tmp[v][0];
			tmp[x][1] += tmp[v][1];
			tmp[x][2] += tmp[v][2];
		}
	}

}

void dfs2(int x) {
	for (int v : g[x]) {
		if (mark[v])
			continue;
		if (!mark[v]) {
			mark[v] = 1;
            //注意次数算贡献一定要将儿子v的tmp值临时存储,因为遍历v之后儿子v的tmp值会变
			int r0 = tmp[v][0]; 
			int r1 = tmp[v][1];
			int r2 = tmp[v][2];
			if (a[x] % 3 == 2) {
				tmp[x][0] -= tmp[v][1];
				tmp[x][1] -= tmp[v][2];
				tmp[x][2] -= tmp[v][0];
			}
			if (a[x] % 3 == 1) {
				tmp[x][0] -= tmp[v][2];
				tmp[x][1] -= tmp[v][0];
				tmp[x][2] -= tmp[v][1];
			}
			if (a[x] % 3 == 0) {
				tmp[x][0] -= tmp[v][0];
				tmp[x][1] -= tmp[v][1];
				tmp[x][2] -= tmp[v][2];
			}

			if (a[v] % 3 == 2) {
				tmp[v][0] += tmp[x][1];
				tmp[v][1] += tmp[x][2];
				tmp[v][2] += tmp[x][0];
			}
			if (a[v] % 3 == 1) {
				tmp[v][0] += tmp[x][2];
				tmp[v][1] += tmp[x][0];
				tmp[v][2] += tmp[x][1];
			}
			if (a[v] % 3 == 0) {
				tmp[v][0] += tmp[x][0];
				tmp[v][1] += tmp[x][1];
				tmp[v][2] += tmp[x][2];
			}
			f[v] = tmp[v][0];
			dfs2(v);
            //回溯
			if (a[x] % 3 == 2) {
				tmp[x][0] += r1;
				tmp[x][1] += r2;
				tmp[x][2] += r0;
			}
			if (a[x] % 3 == 1) {
				tmp[x][0] += r2;
				tmp[x][1] += r0;
				tmp[x][2] += r1;
			}
			if (a[x] % 3 == 0) {
				tmp[x][0] += r0;
				tmp[x][1] += r1;
				tmp[x][2] += r2;
			}

		}
	}
}

int main() {
	cin >> n;
	for (int i = 1; i <= n; i++) {
		cin >> a[i];
	}
	for (int i = 1; i < n; i++) {
		int u, v;
		cin >> u >> v;
		g[u].push_back(v);
		g[v].push_back(u);
	}
    //第一次以1为根遍历,然后再考虑换根
	dfs(1, 0);
	mark[1] = 1;
	f[1] = tmp[1][0];
	for (int v : g[1]) {
        //注意次数算贡献一定要将儿子v的tmp值临时存储,因为遍历v之后儿子v的tmp值会变
		int r0 = tmp[v][0];
		int r1 = tmp[v][1];
		int r2 = tmp[v][2];

		if (!mark[v]) {
			mark[v] = 1;
			if (a[1] % 3 == 2) {
				tmp[1][0] -= tmp[v][1];
				tmp[1][1] -= tmp[v][2];
				tmp[1][2] -= tmp[v][0];
			}
			if (a[1] % 3 == 1) {
				tmp[1][0] -= tmp[v][2];
				tmp[1][1] -= tmp[v][0];
				tmp[1][2] -= tmp[v][1];
			}
			if (a[1] % 3 == 0) {
				tmp[1][0] -= tmp[v][0];
				tmp[1][1] -= tmp[v][1];
				tmp[1][2] -= tmp[v][2];
			}

			if (a[v] % 3 == 2) {
				tmp[v][0] += tmp[1][1];
				tmp[v][1] += tmp[1][2];
				tmp[v][2] += tmp[1][0];
			}
			if (a[v] % 3 == 1) {
				tmp[v][0] += tmp[1][2];
				tmp[v][1] += tmp[1][0];
				tmp[v][2] += tmp[1][1];
			}
			if (a[v] % 3 == 0) {
				tmp[v][0] += tmp[1][0];
				tmp[v][1] += tmp[1][1];
				tmp[v][2] += tmp[1][2];
			}
			f[v] = tmp[v][0];
			dfs2(v);
			if (a[1] % 3 == 2) {
				tmp[1][0] += r1;
				tmp[1][1] += r2;
				tmp[1][2] += r0;
			}
			if (a[1] % 3 == 1) {
				tmp[1][0] += r2;
				tmp[1][1] += r0;
				tmp[1][2] += r1;
			}
			if (a[1] % 3 == 0) {
				tmp[1][0] += r0;
				tmp[1][1] += r1;
				tmp[1][2] += r2;
			}
		}
	}
    //输出结果
	for (int i = 1; i <= n; i++) {
		cout << f[i] << endl;
	}
}




时间复杂度O(n)

#携程笔试题#
 类似资料: