当前位置: 首页 > 工具软件 > Genghis > 使用案例 >

HDU 4126 Genghis Khan the Conqueror (树形DP + MST)

易扬
2023-12-01

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4126

 

题意:给n个点,m条边,每条边权值c,现在要使这n个点连通。现在已知某条边要发生突变,再给q个三元组,每个三元组(a,b,c),(a,b)表示图中可能发生突变的边,该边一定是图中的边。c表示该边新的权值,c只可能比原来的权值大。给的q条边发生突变的概率是一样的。求突变后连通n个点最小代价期望值。

 

思路:上次就看到了这个题,这次又好好做了一下,确实是很好的题,思路来自:http://blog.csdn.net/ophunter_lcm/article/details/12030593

 

跑出MST后,对其上的边进行dfs,每次都是根据割边枚举含i和不含i的子树之间的最短距离,时间复杂度为O(N^2)

 

 

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <functional>
#include <utility>
#include <cmath>
#include <cstring>

using namespace std;

const int maxn = 3010;
const int maxm = 3010 * 3010;
const int inf = 0x3f3f3f3f;

int n, m, q;
int d[maxn][maxn];
int used[maxn][maxn];
int dp[maxn][maxn];

vector <int> g[maxn];

struct edge
{
    int from, to, w, nxt;
    edge() {}
    edge(int from, int to, int w) : from(from), to(to), w(w) {}
    bool operator < (const edge & rhs) const
    {
        return w > rhs.w;
    }
};

struct Prim
{
    int head[maxn], dis[maxn], vis[maxn];
    int dep[maxn], pre[maxn], cnt;
    int n, sum;
    edge e[maxm];

    void init(int n)
    {
        this -> n = n;
        cnt = 0;
        sum = 0;
        memset(head, -1, sizeof(head));
    }

    void add(int u, int v, int w)
    {
        e[cnt].from = u;
        e[cnt].to = v;
        e[cnt].w = w;
        e[cnt].nxt = head[u];
        head[u] = cnt++;
    }

    void prim(int s)
    {
        priority_queue <edge> que;
        memset(dis, inf, sizeof(dis));
        memset(vis, 0, sizeof(vis));
        memset(dep, 0, sizeof(dep));
        memset(pre, -1, sizeof(pre));
        vis[s] = 1;
        dis[s] = 0;
        for(int i = head[s]; ~i; i = e[i].nxt)
        {
            int v = e[i].to;
            dis[v] = e[i].w;
            que.push(edge(s, v, e[i].w));
        }
        while(!que.empty())
        {
            edge now = que.top();
            que.pop();
            int u = now.to;
            if(vis[u])
                continue;
            vis[u] = 1;
            used[now.from][now.to] = used[now.to][now.from] = 1;        //这条边是不是MST上的边
            pre[u] = now.from;
            dep[u] = dep[now.from] + 1;
            dis[u]  = now.w;
            sum += now.w;
            for(int i = head[u]; ~i; i = e[i].nxt)
            {
                int v = e[i].to;
                int w = e[i].w;
                if(!vis[v] && dis[v] > w)
                {
                    dis[v] = e[i].w;
                    que.push(edge(u, v, w));
                }
            }
        }
    }
} mst;

void init(int n)
{
    memset(used, 0, sizeof(used));
    memset(d, inf, sizeof(d));
    memset(dp, inf, sizeof(dp));
    for(int i = 0; i <= n; i++)
    {
        d[i][i] = 0;
        g[i].clear();
    }
}

int dfs(int cur, int u, int fa)     //从cur点来更新cur所在子树和另一子树的最短距离
{
    int ans = inf;
    for(int i = 0; i < g[u].size(); i++)
    {
        int v = g[u][i];
        if(v == fa) continue;
        int tmp = dfs(cur, v, u);       //从cur点来更新以u, v为割边的两子树最短距离
        dp[u][v] = dp[v][u] = min(dp[u][v], tmp);
        ans = min(tmp, ans);        //以fa, u为割边的两子树最短距离
    }
    if(fa != cur)                   //生成树边不更新,因为是以该边为割边
        ans = min(ans, d[u][cur]);
    return ans;
}

int main()
{
    while(~scanf("%d%d", &n, &m) && (n + m))
    {
        mst.init(n);
        init(n);
        for(int i = 0; i < m; i++)
        {
            int u, v, w;
            scanf("%d%d%d", &u, &v, &w);
            mst.add(u, v, w);
            mst.add(v, u, w);
            d[u][v] = d[v][u] = w;
        }
        int s = 0;
        mst.prim(s);
        int ans = 0;
        int sum = mst.sum;
        for(int i = 0; i < n; i++)
            for(int j = 0; j < n; j++)
                if(used[i][j])
                    g[i].push_back(j);
        for(int i = 0; i < n; i++)
            dfs(i, i, -1);
        scanf("%d", &q);
        for(int i = 0; i < q; i++)
        {
            int u, v, w;
            scanf("%d%d%d", &u, &v, &w);
            if(used[u][v])
            {
                if(w < dp[u][v])
                    ans += (sum - d[u][v] + w);
                else
                    ans += (sum - d[u][v] + dp[u][v]);
            }
            else
                ans += sum;
        }
        printf("%.4f\n", 1.0 * ans / q);
    }
    return 0;
}

 

 

 

 

 

 

 

 类似资料: