每增加一条边 $(u, v)$,树上 $u$ 到 $v$ 路径上经过的所有边双都会连成一个大边双。因此只需考虑合并 $x$ 和 $y = fa_x$(即 $x$ 的父亲所在的边双)这两个相邻边双时答案的变化。记 $sz_x$ 为边双 $x$ 的点数(代码中的 `num`),按中间点所在的边双统计三元组,分类讨论:
第一类(三个点都在该边双内)的贡献即为 $A_{sz_x}^3 = sz_x(sz_x-1)(sz_x-2)$。
第二类(中间点和一个端点在该边双内,另一个端点在外)的贡献即为 $2A_{sz_x}^2(n - sz_x)$。
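第三类(中间点在该边双内,两个端点都在外)比较麻烦。两个端点必须位于删去该边双后的不同连通块中,因此贡献为 $sz_x\left((n - sz_x)^2 - h_x\right)$,其中 $h_x$ 为删去该边双后各连通块大小的平方和(代码中的 `h`)。合并时的增量统计方法见代码。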
用并查集维护边双连通分量即可,时间复杂度 $O(n \log n)$。
#include <bits/stdc++.h>
using namespace std;
// Lower x to y when y is strictly smaller (compared with operator<).
// Returns true iff x was changed.
template <typename T>
inline bool upmin(T &x, T y) {
    if (!(y < x)) return false;
    x = y;
    return true;
}

// Raise x to y when y is strictly larger (compared with operator<).
// Returns true iff x was changed.
template <typename T>
inline bool upmax(T &x, T y) {
    if (!(x < y)) return false;
    x = y;
    return true;
}
// Contest-template shorthands. Only PB is used by this solution; the rest
// (and most constants below) are unused boilerplate.
#define MP(A,B) make_pair(A,B)
#define PB(A) push_back(A)
#define SIZE(A) ((int)A.size())
#define LEN(A) ((int)A.length())
#define FOR(i,a,b) for(int i=(a);i<(b);++i)
#define fi first
#define se second
// Every later occurrence of the token `int` is replaced by `ll` (long long),
// so all "int" variables, arrays and return types below are 64-bit. This is
// also why the entry point must be spelled `signed main`.
#define int ll
typedef long long ll;
typedef unsigned long long ull;
typedef long double lod;
typedef pair<int, int> PR;
typedef vector<int> VI;
const lod eps = 1e-11;
const lod pi = acos(-1);
const int mods = 998244353;
const int oo = 1 << 30;
const ll loo = 1ll << 62;
// Capacity of the per-vertex arrays; vertex ids must be < MAXN.
const int MAXN = 600005;
const int INF = 0x3f3f3f3f;//1061109567
/*--------------------------------------------------------------------*/
// Read the next integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. Then a run of decimal digits is consumed.
// If end of input is reached before any digit, 0 is returned. (The previous
// version stored the character in a `char` and never tested for EOF, so it
// looped forever on truncated input.)
inline int read() {
    int sign = 1, x = 0;
    int c = getchar();  // must be int-sized so EOF is distinguishable
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * sign;
}
// Tree adjacency lists (vertices are numbered from 1).
vector<int> e[MAXN];
// fa[v]: parent of v in the tree rooted at 1 (0 for the root).
int fa[MAXN];
// f[v]: DSU parent link; find() follows it to a component's representative.
int f[MAXN];
// dep[v]: depth of v (the root has depth 1).
int dep[MAXN];
// sz[v]: size of v's subtree.
ll sz[MAXN];
// num[r]: number of vertices in the component represented by r.
ll num[MAXN];
// g[v]: not read anywhere in the visible code.
ll g[MAXN];
// h[r]: sum of squared sizes of the pieces left when r's component is removed.
ll h[MAXN];
// Current answer.
ll ans = 0;
// Number of vertices.
ll n;
// Return the representative of x's component and compress the path, so that
// every vertex visited afterwards points directly at the representative.
// Iterative on purpose: climbing a long path in main() can build f-chains of
// length O(n). The recursive form would need one stack frame per link.
int find(int x) {
    int root = x;
    while (f[root] != root) root = f[root];
    while (f[x] != root) {
        int next = f[x];
        f[x] = root;
        x = next;
    }
    return root;
}
// Root the tree at x, whose parent is `father`. Fill fa[], dep[], sz[] and
// h[], and add to ans the initial answer, in which every vertex is its own
// component.
//
// Counting, for a middle vertex w: an ordered pair of other vertices counts
// when they lie in different pieces of the tree with w removed. Those pieces
// are the child subtrees plus the part above w. h[w] receives the sum of the
// squared piece sizes, which merge() uses later.
//
// Implemented with an explicit stack, so a path-shaped tree with ~MAXN
// vertices cannot overflow the call stack. The previous version also
// accumulated g[], which is never read; that dead work is dropped.
void dfs(int x, int father) {
    fa[x] = father;
    dep[x] = dep[father] + 1;

    // Pre-order traversal: every vertex is listed after its parent.
    std::vector<int> order;
    std::vector<int> stk(1, x);
    while (!stk.empty()) {
        int u = stk.back();
        stk.pop_back();
        order.push_back(u);
        for (int v : e[u]) {
            if (v == fa[u]) continue;
            fa[v] = u;
            dep[v] = dep[u] + 1;
            stk.push_back(v);
        }
    }

    // Reverse pre-order: all children are finished before their parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        int u = *it;
        sz[u] = 0;
        for (int v : e[u]) {
            if (v == fa[u]) continue;
            ans += sz[u] * sz[v] * 2;  // ordered pairs across two child subtrees
            h[u] += sz[v] * sz[v];
            sz[u] += sz[v];
        }
        ++sz[u];
        h[u] += (n - sz[u]) * (n - sz[u]);     // the piece above u
        ans += (n - sz[u]) * (sz[u] - 1) * 2;  // piece above u vs. child subtrees
    }
}
// Merge the component represented by x into the component represented by y,
// updating ans incrementally.
// At the only call site (main), x is the representative of the deeper
// component and y = find(fa[x]). The tree edge (x, fa[x]) therefore joins the
// two components.
//
// ans is kept as a sum over components r, with s = num[r]. Each triple is
// counted by the component that contains its middle vertex (the vertex dfs()
// centres its pairs on):
//   part 1: s(s-1)(s-2)         -- all three vertices inside r
//   part 2: 2*s(s-1)*(n-s)      -- middle and one endpoint inside, one outside
//   part 3: s*((n-s)^2 - h[r])  -- middle inside, both endpoints outside and in
//                                  different pieces of the tree minus r
// Here h[r] is the sum of squared sizes of the pieces left when r is removed.
//
// Parts 1 and 2 are subtracted for x and y and re-added for the union.
// Part 3 is changed by its exact difference (3.1 + 3.2), which uses:
//   sz[x] - num[x]      = total size of x's pieces below it
//   n - sz[x] - num[y]  = total size of y's pieces other than x's subtree
//   h[x] - (n-sz[x])^2  = squared sizes of x's pieces below it
//   h[y] - sz[x]^2      = squared sizes of y's pieces other than x's subtree
void merge(int x, int y) {
ans -= num[x] * (num[x] - 1) * (num[x] - 2); //part 1 x
ans -= num[y] * (num[y] - 1) * (num[y] - 2); //part 1 y
ans -= num[x] * (num[x] - 1) * (n - num[x]) * 2; //part 2 x
ans -= num[y] * (num[y] - 1) * (n - num[y]) * 2; //part 2 y
// 3.1: cross terms of the squared-sum expansion.
ans -= (sz[x] - num[x]) * num[x] * num[y] * 2 + (n - sz[x] - num[y]) * num[x] * num[y] * 2; //part 3.1
// 3.2: y's vertices gain x's lower pieces; x's vertices gain y's other pieces.
ans += num[y] * ((sz[x] - num[x]) * (sz[x] - num[x]) - (h[x] - (n - sz[x]) * (n - sz[x]))); //part 3.2 x
ans += num[x] * ((n - sz[x] - num[y]) * (n - sz[x] - num[y]) - (h[y] - sz[x] * sz[x])); //part 3.2 y
// Link x under y. The union keeps all pieces of both components except
// x's upper piece (size n - sz[x]) and y's piece toward x (size sz[x]).
f[x] = y, num[y] += num[x], h[y] += h[x] - sz[x] * sz[x] - (n - sz[x]) * (n - sz[x]);
ans += num[y] * (num[y] - 1) * (num[y] - 2); //part 1 new
ans += num[y] * (num[y] - 1) * (n - num[y]) * 2; //part 2 new
}
signed main() {
n = read();
for (int i = 1, u, v; i < n ; ++ i) u = read(), v = read(), e[u].PB(v), e[v].PB(u);
for (int i = 1; i <= n ; ++ i) f[i] = i, num[i] = 1;
dfs(1, 0);
printf("%lld\n", ans);
int Case = read();
while (Case --) {
int u = read(), v = read(), U = find(u), V = find(v);
while (U != V) {
if (dep[U] < dep[V]) swap(U, V);
merge(U, find(fa[U]));
U = find(U);
}
printf("%lld\n", ans);
}
return 0;
}