CodeForces-294E Shaass the Great

题目大意:
给出一棵n 个节点树
去掉其中一条边,再重新加入一条长度相同的边
问新树中任意两点之间距离的总和最小是多少

原树去掉一条边之后变为两棵树,分别记为AB,它们的节点个数分别为$A_{sz}$,$B_{sz}$
设去掉边长度为d,新连接的两点分别为uv
A 中所有点到u 距离之和为$S_{u}$,同理$S_{v}$
A 中任意两点的距离之和为$C_{A}$,同理$C_{B}$

新树中任意两点的距离和分两部分考虑

  • AB 树内,即$C_{A}+C_{B}$
  • 在两棵树之间

    例如$x-u-v-y$,其中xA 中的节点,yB 中的节点
    将这样的路径分三部分考虑

    • 首先是$u-v$,不难发现共经过$A_{sz}*B_{sz}$次
    • 然后是$x-u$,每个$x-u$都会经过$B_{sz}$次
    • 同理$v-y$

全部加起来就是

观察到$C_{A},C_{B},d,A_{sz},B_{sz}$与$u,v$无关
因此只需找$S_{u},S_{v}$最小的$u,v$

关于$S$的计算分为两个部分,以A 树为例

  • 第一遍dfs 求出每个节点x 的子树大小和子树中节点到x 的距离和

  • 第二遍dfs

    ​ 就是把中心点从x 移到$x_{son}$,$\left (A_{sz}-sz_{x_{son}} \right )$多经过一条边,$sz_{x_{son}}$少经过一条边

有了$S$,$C$就好计算了

每条边$x-y$,会在xy 各计算一次

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<cstdio>
#include<vector>
#define LL long long
using namespace std;
const int N=5050;
const LL INF=1LL<<60;
struct edge{int u,v,val;}t[N];
LL sz[N],c[N],ans=INF;//c[x] 所有点到x的sum
LL s1,s2,r1,r2,n,size;//sum{Any 2} min{Any c[x]}
vector<int> e[N],g[N];
inline int read()
{
register int x=0,t=1;
register char ch=getchar();
while (ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
if (ch=='-') t=-1,ch=getchar();
while (ch>='0'&&ch<='9') x=x*10+ch-48,ch=getchar();
return x*t;
}
void add(int u,int v,int val)
{
e[u].push_back(v);
g[u].push_back(val);
}
void dfs(int o,int fa)
{
sz[o]=1,c[o]=0;
for(int i=0;i<e[o].size();i++)
{
int to=e[o][i];
if (to!=fa)
{
dfs(to,o);
sz[o]+=sz[to];
c[o]+=c[to]+sz[to]*g[o][i];
}
}
}
void calc(int o,int fa,LL &sum,LL &x)
{
sum+=c[o],x=min(x,c[o]);
for(int i=0;i<e[o].size();i++)
{
int to=e[o][i];
if (to!=fa)
{
c[to]=c[o]+(size-sz[to]*2)*g[o][i];
calc(to,o,sum,x);
}
}
}
int main()
{
n=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read();
int val=read();
t[i]=(edge){u,v,val};
add(u,v,val);
add(v,u,val);
}
for(int i=1;i<n;i++)
{
r1=r2=INF,s1=s2=0;
int u=t[i].u,v=t[i].v;
dfs(u,v),size=sz[u];
calc(u,v,s1,r1),s1/=2;
dfs(v,u),size=sz[v];
calc(v,u,s2,r2),s2/=2;
ans=min(ans,sz[u]*sz[v]*t[i].val+s1+s2+r1*sz[v]+r2*sz[u]);
}
printf("%lld\n",ans);
return 0;
}