题解:
先注意到一定存在k种颜色,切成k个块, 然后要求每个块内的颜色都一样,所以可以发现同一种颜色一定在同一个块内,故任意2个相同颜色的最短路劲上的点的颜色都是该颜色。
我们可以先把任意相同颜色点对的路径上的点的颜色都染成这个颜色。 如果发现存在一个点是已经有颜色的话,那么答案一定为0。
至于怎么颜色, 我们可以暴力往上跑,然后压缩路径,和并查集一样的道理,路过的点都染色且压缩。
这样就完成了第一部分的处理,接下来就是树DP了。
定义dp[ u ][ 0 ] 的含义是 以u为根的子树,且u还未被划入任意一种颜色的方案数。
dp[ u ][ 1 ] 的含义是 以u为根的子树,且u已被划入任意一种颜色的方案数。
那么对于有颜色的点来说,
dp[u][0] = 0;
对于没有颜色的点来说:
代码:
/*code by: zstu wxktime: 2019/03/02Tags: 树DPProblem Link: http://codeforces.com/contest/1118/problem/F2Solve: https://www.cnblogs.com/MingSD/p/10462228.html*/#includeusing namespace std;#define Fopen freopen("_in.txt","r",stdin); freopen("_out.txt","w",stdout);#define LL long long#define ULL unsigned LL#define fi first#define se second#define pb push_back#define lson l,m,rt<<1#define rson m+1,r,rt<<1|1#define lch(x) tr[x].son[0]#define rch(x) tr[x].son[1]#define max3(a,b,c) max(a,max(b,c))#define min3(a,b,c) min(a,min(b,c))typedef pair pll;const int inf = 0x3f3f3f3f;const int _inf = 0xc0c0c0c0;const LL INF = 0x3f3f3f3f3f3f3f3f;const LL _INF = 0xc0c0c0c0c0c0c0c0;const LL mod = 998244353;const int N = 3e5 + 100;int n, k, deep[N], fa[N], col[N];vector vc[N], e1[N];void dfs(int o, int u){ deep[u] = deep[o] + 1; fa[u] = o; for(int v : e1[u]){ if(v == o) continue; dfs(u, v); }}bool Link(int u, int v, int k){ int tu = u, tv = v; while(u != v){ if(deep[u] > deep[v]){ u = fa[u]; if(col[u] && col[u] != k) return false; col[u] = k; } else { v = fa[v]; if(col[v] && col[v] != k) return false; col[v] = k; } } while(tu != u){ int tt = fa[tu]; fa[tu] = u; tu = tt; } while(tv != v){ int tt = fa[tv]; fa[tv] = v; tv = tt; } return true;}LL dp[N][2];LL t1[N], t2[N], t3[N];void DFS(int o, int u){ for(int v : e1[u]){ if(v == o) continue; DFS(u, v); } if(col[u]) { dp[u][0] = 0; dp[u][1] = 1; for(int v : e1[u]){ if(v == o) continue; dp[u][1] = (dp[u][1] * (dp[v][0] + dp[v][1])) % mod; } } else { dp[u][0] = 1; dp[u][1] = 0; int k = e1[u].size(); if(k){ for(int i = 0; i < k; ++i){ if(e1[u][i] == o) t1[i+1] = 1; else { t1[i+1] = dp[e1[u][i]][0] + dp[e1[u][i]][1]; dp[u][0] = dp[u][0] * t1[i+1] % mod; } } t2[0] = t3[k+1] = 1; for(int i = 1; i <= k; ++i) t2[i] = t1[i] * t2[i-1] % mod; for(int i = k; i >= 0; --i) t3[i] = t1[i] * t3[i+1] % mod; for(int i = 0; i < k; ++i){ int v = e1[u][i]; if(v == o) continue; dp[u][1] = (dp[u][1] + dp[v][1] * t2[i] % mod * t3[i+2] % mod) % mod; } } }}void Ac(){ for(int i = 1, v; i <= n; ++i){ scanf("%d", &v); if(v) vc[v].pb(i); col[i] = v; } for(int i = 1, u, v; i < n; ++i){ scanf("%d%d", &u, &v); e1[u].pb(v); e1[v].pb(u); } dfs(0, 1); for(int i = 1; i <= n; ++i){ for(int j = 1; j < vc[i].size(); ++j){ if(!Link(vc[i][0],vc[i][j],i)) { puts("0"); return ; } } } DFS(0, 1); printf("%I64d\n", dp[1][1]);}int main(){ while(~scanf("%d%d", &n, &k)){ Ac(); } return 0;}