虚树


先由一道例题引入:消耗战

对于这个问题的每个询问,我们直接树形$DP$就很容易解决了,然而,问题就在于如果我们每次$DP$都遍历整个树,那时间复杂度就爆炸了,但我们发现我们每次询问涉及到的点很少,这时就引入了我们的虚树,我们只需要把我们需要用到的点建立一颗大小为$O(k)$级别的虚树,然后在虚树上$DP$即可,所以我们的难题就变为了如何建立虚树。对于此题,我们发现只要我们预处理出$minv[i]$表示根节点到$i$节点上的最短边,然后我们就可以只靠这些询问点和它们的最近公共祖先就能求出答案,由于只有$k$个叶子节点,单次询问时虚树最多只有$2k-1$个节点(当虚树为完全二叉树时取到)。

在建立虚树前,我们先处理出$dfs$序,和倍增求$lca$用到的一些数据,当我们处理一个询问时,我们将涉及到的点按$dfs$序排序,同时我们用栈维护一条最右链(这条链表示左侧的虚树已经完成构建,而构建的过程中随时可能会有某个$lca$插入最右链中,因此最右链最后再加入虚树)。

初始时,我们直接将第一个询问点加入栈中,然后顺次考虑接下来的询问点,假设当前的询问点为$now$,$anc=lca(now,stk[top])$,我们根据$anc、stk[top]、stk[top-1]$之间的关系,并进行分类讨论。

  1. $anc=stk[top]$,这说明$now$在$anc$的子树中,直接把它加入栈中(最右链中)即可
    image-20220829091543591

  2. $anc$在$stk[top]$和$stk[top-1]$之间(根据$dep$判断)

    image-20220829091842704

    此时我们的$stk[top]$从最右链中出来,而$now$进入了最右链中,在栈中进行相应操作并加一条从$anc$到$stk[top]$的边即可

  3. $anc=stk[top-1]$,只需把$stk[top]$出栈,$now$入栈,并加一条从$anc$到$stk[top]$的边即可

    image-20220829092155517

  4. $anc$是$stk[top-1]$的祖先,即$dep[anc]<dep[stk[top-1]]$,我们需要不断出栈并把边加入虚树中,直到不再是该情况

    image-20220829092530652

  5. 小细节,当我们栈中只有一次元素时,很容易知道我们只会进入情况$1、2$

另外此题我们清空虚树时是边$dfs$边清空,以此保证时间复杂度。然后就奉上本题代码,细节在注释里。

#include<bits/stdc++.h>
#define x first 
#define y second
#define endl '\n'
#define lowbit(x) ((x)&(-x))
#define all(x) (x).begin(),(x).end()
#define mp make_pair
#define pb push_back
#define SZ(x) ((int)(x).size())
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)
using namespace std;
typedef pair<int,int> PII;
typedef vector<int> VI;
typedef long long LL;
typedef unsigned long long ULL;
typedef double DB;
mt19937 randint(random_device{}());
mt19937_64 randLL(random_device{}());
template<typename T=int> 
T gcd(T a,T b){return b?gcd(b,a%b):a;}
template<typename T=int> 
T qmi(T a,T b,T p){T res=1;while(b){if(b&1)res=(LL)res*a%p;b>>=1;a=(LL)a*a%p;}return res%p;}
template<typename T=int>
T read(){
	T x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){
		if(ch=='-')
			f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		x=x*10+ch-'0';
		ch=getchar();
	}
	return x*f;
}
//----------------------------------------head------------------------------------
const LL INF=0x3f3f3f3f3f3f3f3f;
const int N=3e5+10;
int n,m,stk[N],top,dfn[N],dep[N],fa[N][20],num,st[N];
LL minv[N];//开LL
vector<PII> to[N];//原树
VI vt[N];//虚树
bool que[N];
void dfs(int u,int fat){
	dep[u]=dep[fat]+1;
	dfn[u]=++num;
	fa[u][0]=fat;
	for(int i=1;i<20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
	for(auto edge:to[u]){
		int j=edge.x,w=edge.y;
		if(j==fat) continue;
		minv[j]=min(minv[u],1ll*w);
		dfs(j,u);
	}
}
int lca(int a,int b){
	if(dep[a]<dep[b]) swap(a,b);
	for(int k=19;k>=0;k--) if(dep[fa[a][k]]>=dep[b]) a=fa[a][k];
	if(a==b) return a;
	for(int k=19;k>=0;k--)
		if(fa[a][k]!=fa[b][k])
			a=fa[a][k],b=fa[b][k];
	return fa[a][0];
}
LL dfsans(int u){//在虚树上DP
	LL sum=0,ans=0;
	for(auto j:vt[u]) sum+=dfsans(j);
	if(que[u]) ans=minv[u];
	else ans=min(minv[u],sum);
	que[u]=0;//去掉标记
	vt[u].clear();//清空虚树
	return ans;
}
signed main(){
    n=read();
	for(int i=1;i<n;i++){
		int u=read(),v=read(),w=read();
		to[u].pb({v,w});
		to[v].pb({u,w});
	}
	minv[1]=INF;
	dfs(1,0);
	m=read();
	while(m--){
		int num=read();
		for(int i=1;i<=num;i++) st[i]=read(),que[st[i]]=1;//que[i]为1表示i是资源丰富的岛屿
		sort(st+1,st+num+1,[](int a,int b){
			return dfn[a]<dfn[b];
		});
		stk[top=1]=st[1];
		for(int i=2;i<=num;i++){
			int now=st[i];
			int anc=lca(now,stk[top]);
			while(1){
				if(dep[anc]>=dep[stk[top-1]]){
					if(anc!=stk[top]){
						vt[anc].pb(stk[top]);
						if(anc!=stk[top-1]) stk[top]=anc;//情况2
						else top--;//情况3
					}
					break;//情况1
				}
				else{//情况4
					vt[stk[top-1]].pb(stk[top]);
					top--;
				}
			}
			stk[++top]=now;
		}
		while(--top) vt[stk[top]].pb(stk[top+1]);//加入最右链
		printf("%lld\n",dfsans(stk[1]));//stk[1]就相当于虚树的根
	}
	return 0;
}

再来道练习题:世界树

对于本题,答案的计算没有上一题那么简单,我们可以先用两遍$dfs$求出距离每个点最近的临时议事处和这个距离(一遍求儿子中距离最近的,一遍求父亲中距离最近的),然后我们发现,对于不在虚树上的点,可以分为两种情况,一种是在虚树路径上的点,一种是不在虚树路径上的点。不在虚树路径上的点,就向距离它们最近的虚树上的点贡献答案,在虚树路径上的点就向这条路径两端的点贡献答案,且具有二分性(我们直接用倍增来求)。

image-20220829142925555

#include<bits/stdc++.h>
#define x first
#define y second
#define endl '\n'
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define lowbit(x) ((x)&-(x))
#define SZ(x) (int)x.size()
using namespace std;
using ll=long long;
using ld=long double;
using pii=pair<int,int>;
using pll=pair<ll,ll>;
using pdd=pair<ld,ld>;
const int INF=0x3f3f3f3f;
const int N=3e5+10;
vector<int> vt[N];
vector<int> to[N];
int st[N],stk[N],top,que[N],cnt,dfn[N],dep[N],fa[N][20],timestamp,n;
int dist[N],id[N],ans[N],pre[N],sz[N];
void dfs(int u,int fat){
	dfn[u]=++timestamp;
	sz[u]=1;
	dep[u]=dep[fat]+1;
	fa[u][0]=fat;
	for(int i=1;i<20;i++)
		fa[u][i]=fa[fa[u][i-1]][i-1];
	for(auto j:to[u]){
		if(j==fat) continue;
		dfs(j,u);
		sz[u]+=sz[j];
	}
}
int lca(int a,int b){
	if(dep[a]<dep[b]) swap(a,b);
	for(int k=19;~k;k--)
		if(dep[fa[a][k]]>=dep[b]) a=fa[a][k];
	if(a==b) return a;
	for(int k=19;~k;k--)
		if(fa[a][k]!=fa[b][k])
			a=fa[a][k],b=fa[b][k];
	return fa[a][0];
}
void dfsson(int u){
	dist[u]=INF;
	ans[u]=0;
	if(que[u]) dist[u]=0,id[u]=u;
	for(auto j:vt[u]){
		dfsson(j);
		int w=dep[j]-dep[u];
		if(dist[j]+w<dist[u]) dist[u]=dist[j]+w,id[u]=id[j];
		else if(dist[j]+w==dist[u]&&id[j]<id[u]) id[u]=id[j];
	}
}
void dfsfa(int u){
	for(auto j:vt[u]){
		int w=dep[j]-dep[u];
		if(dist[u]+w<dist[j]) dist[j]=dist[u]+w,id[j]=id[u];
		else if(dist[u]+w==dist[j]&&id[u]<id[j]) id[j]=id[u];
		dfsfa(j);
	}
}
int getpos(int u,int j){
	for(int i=19;~i;i--)
		if(dep[fa[j][i]]>dep[u])
			j=fa[j][i];
	return j;
}
int getbound(int u,int j){
	int str=j;
	if(id[u]<id[j]){
		for(int i=19;~i;i--)
			if(dist[str]+dep[str]-dep[fa[j][i]]<dist[u]+dep[fa[j][i]]-dep[u])
				j=fa[j][i];
		return j;
	}
	else{
		for(int i=19;~i;i--)
			if(dist[str]+dep[str]-dep[fa[j][i]]<=dist[u]+dep[fa[j][i]]-dep[u])
				j=fa[j][i];
		return j;
	}
}
void dfsans(int u){
	int sum=0;
	for(auto j:vt[u]){
		int x=getpos(u,j);//找到原树中u下面一个点,这样处理会容易很多
		sum+=sz[x];
		if(id[u]==id[j]) ans[id[u]]+=sz[x]-sz[j];
		else{
			int y=getbound(u,j);
			ans[id[u]]+=sz[x]-sz[y];
			ans[id[j]]+=sz[y]-sz[j];
		}
		dfsans(j);
	}
	vt[u].clear();
	que[u]=0;
	ans[id[u]]+=sz[u]-sum;
}
void solve(){
	cin>>n;
	for(int i=1,a,b;i<n;i++){
		cin>>a>>b;
		to[a].push_back(b);
		to[b].push_back(a);
	}
	dfs(1,0);
	cin>>cnt;
	while(cnt--){
		int num,prenum;
		cin>>num;
		prenum=num;
		bool flag=0;//给定的点中是否有1
		for(int i=1;i<=num;i++){
			cin>>st[i],pre[i]=st[i],que[st[i]]=1;
			if(st[i]==1) flag=1;
		}
		if(!flag) st[++num]=1;//没有1的话就自己添加一个1,这样处理会比较简单
		sort(st+1,st+num+1,[](int a,int b){return dfn[a]<dfn[b];});
		stk[top=1]=st[1];
		for(int i=2;i<=num;i++){
			int now=st[i];
			int anc=lca(now,stk[top]);
			while(1){
				if(dep[anc]>=dep[stk[top-1]]){
					if(anc!=stk[top]){
						vt[anc].pb(stk[top]);
						if(anc!=stk[top-1]) stk[top]=anc;
						else top--;
					}
					break;
				}
				else{
					vt[stk[top-1]].pb(stk[top]);
					top--;
				}
			}
			stk[++top]=now;
		}
		while(--top) vt[stk[top]].pb(stk[top+1]);
		dfsson(1);
		dfsfa(1);
		dfsans(1);
		for(int i=1;i<=prenum;i++) cout<<ans[pre[i]]<<" \n"[i==prenum];
	}
}
signed main(){
	ios::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	solve();
	return 0;
}

文章作者: verynewabie
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 verynewabie !
评论
  目录