本文最后更新于:2020年4月19日 晚上

适用条件:动态开点的(权值)线段树。

关于时间复杂度的结论:

  • 每次合并的代价是两棵树的公共节点数。
  • 若有n棵含有单个元素的树,经过n-1次merge操作,将他们合并成一棵树的代价是O(nlogn)O(nlogn)O(nlogC)O(nlogC)
  • 单次merge操作开销可大可小,均摊下一次就是一个log的。

关于空间复杂度,普通版本是O(nlogn)O(nlogn)的。

直接给出代码,merge操作十分简洁:

void insert(int &x,int y,int l,int r)
{
    ++sum[x=++tot];
    if(l==r) return;
    int mid=l+r>>1;
    if(y<=mid) insert(ls[x],y,l,mid);
    else insert(rs[x],y,mid+1,r);
}
int merge(int x,int y)
{
    if(!x||!y) return x+y;
    sum[x]=sum[x]+sum[y];
    ls[x]=merge(ls[x],ls[y]);
    rs[x]=merge(rs[x],rs[y]);
    return x;
}

题目

1.[BZOJ 2212: Poi2011]Tree Rotations

递归的给一颗二叉树,只有叶子有权值,对于每个非叶节点可以交换左右子树,使遍历后构成的序列逆序对最小。

对于一个子树,逆序对来自三部分,一个完全在左子树,一个完全在右子树,还有就是跨越左右子树的逆序对,可以发现交换左右子树的操作只会改变跨越部分的贡献,这里可以在线段树合并过程中求出逆序对的个数。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define N 600010
int n,m;
int son[N][2];
int root[N],rt;
int tot,ls[N*20],rs[N*20],sum[N*20];
ll ans0,ans1,ans;
void insert(int &x,int y,int l,int r)
{
    ++sum[x=++tot];
    if(l==r) return;
    int mid=l+r>>1;
    if(y<=mid) insert(ls[x],y,l,mid);
    else insert(rs[x],y,mid+1,r);
}
void read(int &x)
{
    x=++m;
    int y;
    scanf("%d",&y);
    if(y)
    {
        insert(root[x],y,1,n);
        return;
    }
    read(son[x][0]);
    read(son[x][1]);
}
int merge(int x,int y)
{
    if(!x||!y) return x+y;
    ans0+=1LL*sum[rs[x]]*sum[ls[y]];
    ans1+=1LL*sum[rs[y]]*sum[ls[x]];
    sum[x]=sum[x]+sum[y];
    ls[x]=merge(ls[x],ls[y]);
    rs[x]=merge(rs[x],rs[y]);
    return x;
}
void solve(int x)
{
    if(!son[x][0] && !son[x][1]) return;
    solve(son[x][0]);
    solve(son[x][1]);
    ans0=ans1=0;
    root[x]=merge(root[son[x][0]],root[son[x][1]]);
    ans+=min(ans0,ans1);
}
int main()
{
    scanf("%d",&n);
    read(rt);
    solve(rt);
    printf("%lld\n",ans);
    return 0;
}

2.[BZOJ 2733: HNOI2012]永无乡

每个元素有点权,支持合并两个集合,查询一个集合内第k小的元素。

之前的做法是启发式合并+平衡树,每个元素最多插入O(logn)O(logn)次,插入一次O(logn)O(logn),所以总的时间复杂度O(nlog2n)O(nlog^{2}n)

用线段树合并就可以做到O(nlogn)O(nlogn)了,查询非常简单。

#include<bits/stdc++.h>
using namespace std;
#define N 100010
int tot;
int sum[N*20],ls[N*20],rs[N*20];
int fa[N],siz[N],root[N],id[N],val[N];
int n,m,q;
void insert(int &x,int y,int l,int r)
{
    ++sum[x=++tot];
    if(l==r) return;
    int mid=l+r>>1;
    if(y<=mid) insert(ls[x],y,l,mid);
    else insert(rs[x],y,mid+1,r);
}
int merge(int x,int y)
{
    if(!x||!y) return x+y;
    sum[x]=sum[x]+sum[y];
    ls[x]=merge(ls[x],ls[y]);
    rs[x]=merge(rs[x],rs[y]);
    return x;
}
int ask(int x,int y,int l,int r)
{
    if(l==r) return l;
    int mid=l+r>>1;
    if(y<=sum[ls[x]]) return ask(ls[x],y,l,mid);
    else return ask(rs[x],y-sum[ls[x]],mid+1,r);
}
int getfa(int x)
{
    return x==fa[x]?x:fa[x]=getfa(fa[x]);
}
void Union(int x,int y)
{
    x=getfa(x);y=getfa(y);
    if(x==y) return;
    if(siz[x]>siz[y]) swap(x,y);
    fa[x]=y;
    siz[y]+=siz[x];
    root[y]=merge(root[x],root[y]);
}
int Ask(int x,int y)
{
    x=getfa(x);
    if(siz[x]<y) return -1;
    return id[ask(root[x],y,1,n)];
}

int main()
{
    scanf("%d%d",&n,&m);
    tot=0;
    for(int i=1;i<=n;++i)
    {
        scanf("%d",&val[i]);id[val[i]]=i;
        fa[i]=i;
        siz[i]=1;
        insert(root[i],val[i],1,n);
    }
    for(int i=1;i<=m;++i)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        Union(x,y);
    }
    scanf("%d",&q);
    while(q--)
    {
        char opt[3];
        int x,y;
        scanf("%s%d%d",opt,&x,&y);
        if(opt[0]=='Q') printf("%d\n",Ask(x,y));
        else Union(x,y);
    }
    return 0;
}

3.2018 南京网络赛 H set

首先这个题不是线段树,而是trie树,我们把每个数二进制从低位到高位插入trie,维护子树中插入的数的个数,合并类似线段树合并,修改操作比较巧妙,+1就相当于最低位+1,也就是最低位0变1,1变0,那么交换左右子树即可,进位的话就是进入(交换后)的左子树继续交换左右子树,递归下去。这样总的时间复杂度就是O(nlogC)O(nlogC)

#include<bits/stdc++.h>
using namespace std;
#define N 600010
int tot;
int sum[N*32],son[N*32][2];
int fa[N],root[N];
int n,m;
void read(int &x)
{
    char ch;
    while(!isdigit(ch=getchar()));
    x=ch-'0';
    while(isdigit(ch=getchar()))
        x=x*10+ch-'0';
}

int insert(int y)
{
    int now=++tot;
    int res=now;
    ++sum[now];
    for(int i=0;i<32;++i,y>>=1)
    {
        son[now][y&1]=++tot;
        ++sum[now=son[now][y&1]];
    }
    return res;
}
int merge(int x,int y)
{
    if(!x||!y) return x+y;
    sum[x]=sum[x]+sum[y];
    son[x][0]=merge(son[x][0],son[y][0]);
    son[x][1]=merge(son[x][1],son[y][1]);
    return x;
}
void change(int x)
{
    if(x==0) return;
    swap(son[x][0],son[x][1]);
    change(son[x][0]);
}
int ask(int x,int k,int y)
{
    for(int i=0;i<k;++i,y>>=1)
        x=son[x][y&1];
    return sum[x];
}
int getfa(int x)
{
    return x==fa[x]?x:fa[x]=getfa(fa[x]);
}
void Union(int x,int y)
{
    x=getfa(x);y=getfa(y);
    if(x==y) return;
    fa[x]=y;
    root[y]=merge(root[x],root[y]);
}
int main()
{
    read(n);read(m);
    for(int i=1;i<=n;++i)
    {
        int y;read(y);
        root[i]=insert(y);
        fa[i]=i;
    }
    while(m--)
    {
        int opt;
        scanf("%d",&opt);
        if(opt==1)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            Union(x,y);
        }
        else if(opt==2)
        {
            int x;
            scanf("%d",&x);
            change(root[getfa(x)]);
        }
        else
        {
            int x,k,y;
            scanf("%d%d%d",&x,&k,&y);
            printf("%d\n",ask(root[getfa(x)],k,y));
        }
    }
    return 0;
}