简单介绍
伸展树($splay$),也叫分裂数,是一种平衡二叉树,能在$O(log\ n)$的时间复杂度内完成插入,查找和删除操作,比较好写而且很实用,$LCT$也经常借助$splay$来实现,是一种在竞赛中比较常用的数据结构。本篇就来给大家介绍一下这个数据结构,关于时间复杂度的证明本人就不再赘述(有点难)。读者应当了解二叉搜索树及其缺陷。
平衡树的左旋和右旋
对一棵二叉搜索树,当我们进行如下的旋转时,我们发现各节点之间的大小关系依然是正确的,因而我们得以通过旋转保证二叉搜索树的平衡(高度接近于$log\ n$)以保证插入、查找和删除的时间复杂度。
而我们的$splay$,每操作一个节点就把它旋转到根节点,从而保证了时间复杂度(我不太会证)。
维护信息
对于$splay$的一个节点,我们需要维护:子树大小、两个儿子、父亲和该点权值(根据题目不同还要维护其它的信息)
struct node{
int s[2],p;//父亲为0的就是根节点
int val,sz;
}tr[N];
因为我们要旋转,旋转后要更新相关信息,大小的维护我们使用一个$pushup$函数
//为了方便
#define ls tr[u].s[0]
#define rs tr[u].s[1]
void pushup(int u){
tr[u].sz=tr[ls].sz+tr[rs].sz+1;
}
新建节点
int newnode(int val,int p){
int u=++idx;
tr[u].p=p;
tr[u].val=val;
return u;
}
需要节点的权值和父亲信息
单次旋转
非常简单的写法,省去了繁琐的分类讨论,可以自己代入各种情况试一下,都是对的,本人这里就不赘述。
void rotate(int x){//这是个工具函数,我们不会直接调用它,其它函数会调用这个函数
int y=tr[x].p,z=tr[y].p;//其它函数会保证只有z可能为0,这时也是正确的
tr[z].s[y==tr[z].s[1]]=x,tr[x].p=z;//虽然我们可能会修改tr[0]的儿子信息,但我们用不到它
int k=x==tr[y].s[1];
tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
tr[x].s[k^1]=y,tr[y].p=x;
pushup(y),pushup(x);//通过上图我们可以发现,只有x节点和它的父节点的信息需要pushup
}
旋转
这里的旋转是将$x$点旋转到另一点$k$下面。我们根据$x$、它的父亲$y$和$y$的父亲$z$三者间的关系来选择不同的旋转方式。(这样可以保证时间复杂度,同样的:不会证)
- $x,y,z$在一条直线上,这时我们先旋转$y$,再旋转$x$
- $x,y,z$不在一条直线上,这时我们先旋转$x$,再旋转$x$
同时我们还要注意,如果$z$已经是$k$了,那我们只要旋转一次$x$就可以了
void splay(int x,int k){
while(tr[x].p!=k){
int y=tr[x].p,z=tr[y].p;
if(z!=k){
if((tr[z].s[1]==y)^(tr[y].s[1]==x)) rotate(x);//比较方便的判断
else rotate(y);
}
rotate(x);
}
if(k==0) root=x;//把x旋转到0下面,说明0变为了根节点
}
建树
当它一次性给我们一些有序节点时,我们可以$O(n)$建树
int build(int l,int r,int p){
int mid=l+r>>1;
int u=newnode(w[mid],p);
if(l<mid) ls=build(l,mid-1,u);
if(mid<r) rs=build(mid+1,r,u);
pushup(u);
return u;
}
插入节点
int insert(int val){
int u=root,p=0;
while(u){
p=u;
u=tr[u].s[val>tr[u].val];//根据权值判断该进入左子树还是右子树
}
u=newnode(val,p);
if(p) tr[p].s[val>tr[p].val]=u;//如果不是根节点,更新父节点的儿子信息
splay(u,0);//我们不用pushup(p)的原因就在于我们会把u转到根,路径上的点都会更新
return u;
}
查询
这里介绍查找树中查找第$k$小数的方法,其它也类似。
int get_kth(int k){//找到第k小的数对应节点编号
int u=root;
while(u){
if(tr[ls].sz>=k) u=ls;
else if(tr[ls].sz+1==k) return u;
else k-=tr[ls].sz+1,u=rs;
}
return -1;//不足k个数
}
再介绍一下查后继的方法(查找大于等于$val$的第一个数)
int get(int val){//找到大于等于某个数的第一个数,同样是返回编号
int u=root,res=-1;//不存在答案就返回-1
while(u){
if(tr[u].val>=val) res=u,u=ls;
else u=rs;
}
return res;
}
删除
void del(int& root,int x){
int u=root;
while(u){
if(tr[u].val==x) break;
if(tr[u].val<x) u=rs;
else u=ls;
}
//若x不存在,则u在这会是0
if(u==0) return ;
splay(u,0);
int l=ls,r=rs;
while(tr[l].s[1]) l=tr[l].s[1];//找到该点前驱
while(tr[r].s[0]) r=tr[r].s[0];//找到该点后继
splay(l,0),splay(r,l);//将l转到根,r转到l下面,而r节点值大于l节点,此时r的左儿子就是要删的数
tr[r].s[0]=0;
pushup(r),pushup(l);//更新需要更新的l和r
}
文艺平衡树
$splay$还可以通过$pushdown$支持一堆骚操作,我们通过例题来介绍其中的一个。
这题我们可以直接维护整个序列的中序遍历,可以看成每个点的权值就是它的下标($1\sim n$),当我们需要翻转一个区间时,我们就把这个区间“取出”,然后交换并加上懒标记,至于如何操作的,详见代码。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
LL read(){
LL x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9'){
if(ch=='-')
f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9'){
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
const int N=1e5+10;
int n,m,root,idx;
struct node{
int s[2],p,val,sz,flag;
}tr[N];
void pushup(int u){
tr[u].sz=tr[tr[u].s[0]].sz+tr[tr[u].s[1]].sz+1;
}
void pushdown(int u){
if(tr[u].flag){
swap(tr[u].s[0],tr[u].s[1]);
tr[tr[u].s[0]].flag^=1;
tr[tr[u].s[1]].flag^=1;
tr[u].flag=0;
}
}
void rotate(int x){
int y=tr[x].p,z=tr[y].p;
int k=tr[y].s[1]==x;
tr[z].s[y==tr[z].s[1]]=x;
tr[x].p=z;
tr[y].s[k]=tr[x].s[k^1];
tr[tr[x].s[k^1]].p=y;
tr[x].s[k^1]=y;
tr[y].p=x;
pushup(y);
pushup(x);
}
void splay(int x,int k){
while(tr[x].p!=k){
int y=tr[x].p,z=tr[y].p;
if(z!=k){
if((tr[z].s[1]==y)^(tr[y].s[1]==x)) rotate(x);
else rotate(y);
}
rotate(x);
}
if(k==0) root=x;
}
void insert(int val){
int u=root,p=0;
while(u){
p=u;
u=tr[u].s[val>tr[u].val];
}
u=++idx;
if(p) tr[p].s[val>tr[p].val]=u;
tr[u].p=p;
tr[u].val=val;
tr[u].sz=1;
splay(u,0);
}
int get_k(int k){
int u=root;
while(1){
pushdown(u);//查找前别忘了pushdown
if(k<=tr[tr[u].s[0]].sz) u=tr[u].s[0];
else if(tr[tr[u].s[0]].sz+1==k) return u;
else k-=tr[tr[u].s[0]].sz+1,u=tr[u].s[1];
}
}
void print(int u){
//打印前别忘了pushdown
pushdown(u);
if(tr[u].s[0]) print(tr[u].s[0]);
if(tr[u].val>=1&&tr[u].val<=n) printf("%d ",tr[u].val);
if(tr[u].s[1]) print(tr[u].s[1]);
}
signed main(){
n=read(),m=read();
for(int i=0;i<=n+1;i++) insert(i);//加入两个哨兵0和n+1,这样更方便
while(m--){
int l=read(),r=read();
//另外补充,我们在执行get_k函数时就把路径上的懒标记清空了,因此rotate和splay时都不用pushdown
l=get_k(l),r=get_k(r+2);//因为我们加入了哨兵,这里实际上是找第l-1个数和第r+1个数
splay(l,0);//l转到根
splay(r,l);//r转到l下面
//r的左儿子的子树就是l~r区间了
tr[tr[r].s[0]].flag^=1;//加上懒标记
}
print(root);
return 0;
}