下面给出Splay的实现方法(复杂度证明什么的知道是 nlogn 就可以啦)
首先对于一颗可爱的二叉查找树,是不能保证最坏nlogn的复杂度(可以想象把一个升序序列插入)
所以我们需要一些非常巧妙的旋转操作
1 #include<cstdio> 2 #include<algorithm> 3 #include<cstring> 4 #include<cmath> 5 #define N 100010 6 #define which(x) (ls[fa[(x)]]==(x)) 7 typedef long long ll; 8 using namespace std; 9 int n,root,idx,val[N],fa[N],ls[N],rs[N],sze[N],cnt[N]; 10 int read() 11 { 12 int ret=0,neg=1; 13 char j=getchar(); 14 for (;j>‘9‘ || j<‘0‘;j=getchar()) 15 if (j==‘-‘) neg=-1; 16 for (;j>=‘0‘ && j<=‘9‘;j=getchar()) 17 ret=ret*10+j-‘0‘; 18 return ret*neg; 19 } 20 void upt(int x)//更新子树大小 21 { 22 sze[x]=sze[ls[x]]+sze[rs[x]]+cnt[x]; 23 } 24 void rotate(int x)//旋转操作 25 { 26 //y是x父亲,z是y父亲,b是y的另一个儿子 27 int y=fa[x],z=fa[y],b=which(x)?rs[x]:ls[x],dir=which(y); 28 which(x)?(rs[x]=y,ls[y]=b):(ls[x]=y,rs[y]=b); 29 fa[y]=x,fa[b]=y,fa[x]=z; 30 if (z) dir?ls[z]=x:rs[z]=x; 31 upt(y),upt(x);//更新大小 32 } 33 void splay(int x)//把x旋转至根节点 34 { 35 //为了让树平衡,如果x和父亲同向,转fa[x]染红转x 36 //否则转两次x 37 while (fa[x]) 38 { 39 if (fa[fa[x]]) 40 if (which(x)==which(fa[x])) rotate(fa[x]); 41 else rotate(x); 42 rotate(x); 43 } 44 root=x;//现在x是根了 45 } 46 int getmin(int x)//找以x为根子树最小值节点编号 47 { 48 while (ls[x]) x=ls[x]; 49 return x; 50 } 51 int getmax(int x)//找以x为根子树最大值节点编号 52 { 53 while (rs[x]) x=rs[x]; 54 return x; 55 } 56 int find(int x)//找值为x的节点没有返回0 57 { 58 int cur=root,last=0; 59 while (cur && val[cur]!=x) 60 { 61 last=cur; 62 if (x<val[cur]) cur=ls[cur]; 63 else cur=rs[cur]; 64 } 65 return cur?cur:last; 66 } 67 void insert(int x)//插入x 68 { 69 int cur=find(x);//找到 70 //如果已经存在x,把x++后splay成根节点 71 if (cur && val[cur]==x) return (void)(cnt[cur]++,sze[cur]++,splay(cur)); 72 //如果不存在x就创造一个,然后splay 73 val[++idx]=x,fa[idx]=cur,cnt[idx]=sze[idx]=1; 74 if (cur) x<val[cur]?ls[cur]=idx:rs[cur]=idx; 75 splay(idx); 76 } 77 void erase(int x)//删除值为x的节点 78 { 79 int cur=find(x);//保证存在 80 splay(cur);//先把x转到根 81 //如果x个数大于1,直接删掉就好 82 if (cnt[cur]>1) cnt[cur]--,sze[cur]--; 83 //如果有一个儿子节点为空,直接让另一个为根,如果都是空就说明树为空 84 else if (!ls[cur] || !rs[cur]) root=ls[cur]+rs[cur],fa[root]=0; 85 else 86 { 87 fa[ls[cur]]=0;//x的左儿子没爸爸了 88 int u=getmax(ls[cur]);//让左子树最大值节点当新根节点,右子树的根节点是新根节点的右儿子 89 splay(u); 90 rs[u]=rs[cur],fa[rs[cur]]=u; 91 upt(u); 92 } 93 } 94 int getkth(int k)//寻找第k大,比较easy 95 { 96 int cur=root; 97 while (cur) 98 { 99 if (sze[ls[cur]]>=k) cur=ls[cur]; 100 else if (sze[ls[cur]]+cnt[cur]>=k) return val[cur]; 101 else k-=sze[ls[cur]]+cnt[cur],cur=rs[cur]; 102 } 103 return val[cur]; 104 } 105 int getrank(int x)//询问x排名 106 { 107 int cur=find(x); 108 splay(cur); 109 return sze[ls[cur]]+1; 110 } 111 int getpre(int x)//找前驱 112 { 113 int cur=find(x); 114 if (val[cur]<x) return val[cur]; 115 splay(cur); 116 return val[getmax(ls[cur])]; 117 } 118 int getnxt(int x)//找后继 119 { 120 int cur=find(x); 121 if (val[cur]>x) return val[cur]; 122 splay(cur); 123 return val[getmin(rs[cur])]; 124 } 125 int main() 126 { 127 n=read(); 128 for (int i=1,op,x;i<=n;i++) 129 { 130 op=read(),x=read(); 131 if (op==1) insert(x); 132 if (op==2) erase(x); 133 if (op==3) printf("%d\n",getrank(x)); 134 if (op==4) printf("%d\n",getkth(x)); 135 if (op==5) printf("%d\n",getpre(x)); 136 if (op==6) printf("%d\n",getnxt(x)); 137 } 138 return 0; 139 }