本文最后更新于:2020年8月24日 下午

问题的引入

一个n×mn \times m二维矩阵AA的前缀和SS,一般来说定义为Sx,y=i=1xj=1yAi,jS_{x,y}=\sum_{i=1}^{x}\sum_{j=1}^{y}A_{i,j}

代码常见的写法就是用容斥加加减减:

for(int i=1;i<=n;++i)
    for(int j=1;j<=m;++j)
        S[i][j]=S[i-1][j]+S[i][j-1]-S[i-1][j-1]+A[[i][j];

这是二维的情况,如果是三维或者是更高的 kk 维,不难发现复杂度有一个 2k2^k,指数爆炸了。

然而,用高维前缀和的技巧,就可以有效的将时间复杂度降下来。具体的思路是一维一维的求和。比如还是刚才二维的例子,为方便叙述,想象矩阵左上角是(1,1)(1,1),我们先对每行分别从左到右累加,即对每行求一维前缀和,再对每列从上到下累加,还是相同的过程,这样求出的结果就是二维前缀和了。也可以从定义前缀和的式子入手考虑,每个求和号就代表一个维度,我们的计算过程就相当于一个求和号一个求和号的算,比较简单。

高维前缀和代码,简洁的一匹:

for(int i=1;i<=n;++i)
    for(int j=1;j<=m;++j)
        S[i][j]=A[i][j];

for(int i=1;i<=n;++i)
    for(int j=1;j<=m;++j)
        S[i][j]+=S[i][j-1];
for(int i=1;i<=n;++i)
    for(int j=1;j<=m;++j)
        S[i][j]+=S[i-1][j];

子集和变换

高维前缀和在二进制数上的应用就是做子集和变换。比如有一些小于2202^{20}正整数给一个数xx,我要统计所有满足 x&y=yx\&y=y 的数 yy 有多少个,这里的and就是二进制与。这里其实还是高维前缀和,可以把每一个数看成2020维超立方体的其中一个格子,满足上述式子的xxyy的关系就是,每一个维度下yy对应的值都要小于等于xx对应的值(虽然每个维度只有两种值,0或1)这样就是相同的问题了,于是我们不仅可以维护一个子集的信息,还可以维护超集的信息,维护的内容也不只限于求和,还可以求最值等等。

求超集和:

void doit(int *f,int n)
{
    int len=1<<n;
    for(int i=0;i<n;++i)
        for(int j=0;j<len;++j)
            if(~j&(1<<i))
                f[j]+=f[j^(1<<i)];
}

第一层循环枚举维度,第二层枚举所有元素,注意第二层循环正序或者倒序都是一样的,因为每个维度大小只有2,即0或1,做的就是把“1”加到“0”上。

练习题目

1.SPOJ Time Limit Exceeded

简单的递推一下就是记dpi,jdp_{i,j}表示考虑前ii个数,第ii个数为jj的方案数,转移方程很容易dpi,j=k=02m1dpi1,k[j&k=0]dp_{i,j}=\sum_{k=0}^{2^m-1}dp_{i-1,k}[j\&k=0],并且如果j是cic_i的倍数,那么dpi,j=0dp_{i,j}=0

然而朴素的转移是O(n(2m)2)O(n(2^m)^2),显然会TLE。转化一下j&k=0j\& k=0等价于j&(k)=(k)j\&(\sim k)=(\sim k),那么我们把上一次的dp值状态取反,然后求超集和就ok了,时间复杂度O(n22m)O(n^{2}2^{m})

#include<bits/stdc++.h>
using namespace std;
const int mod=1e9;
#define N 55
#define M 16
int n,m;
int c[N];
int dp[1<<M];
void doit(int *f,int n)
{
    int len=1<<n;
    for(int i=0;i<n;++i)
        for(int j=0;j<len;++j)
            if(~j&(1<<i))
                f[j]=(f[j]+f[j^(1<<i)])%mod;
}
int main()
{
    int T;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d%d",&n,&m);
        for(int i=1;i<=n;++i) scanf("%d",c+i);
        int len=1<<m;
        for(int i=1;i<len;++i)dp[i]=0;
        dp[0]=1;
        for(int i=1;i<=n;++i)
        {
            for(int j=0;j<len;j+=2) swap(dp[j],dp[j^(len-1)]);
            doit(dp,m);
            for(int j=0;j<len;j+=c[i]) dp[j]=0;
        }
        int ans=0;
        for(int i=0;i<len;++i) ans=(ans+dp[i])%mod;
        printf("%d\n",ans);
    }
}

2.codeforces 449D - Jzzhu and Numbers

求and起来为0的子集的个数。

先用高维前缀和算出dpSdp_S表示状态包含S的数的个数,然后令2dpS12^{dp_S}-1就是and起来包含S的子集的个数,然后再高维差分回去,就能求出答案了。

#include<bits/stdc++.h>
using namespace std;
const int mod=1e9+7;
int n;
int dp[1<<20];
int power(int x,int n)
{
    int ans=1;
    while(n)
    {
        if(n&1) ans=1LL*ans*x%mod;
        x=1LL*x*x%mod;
        n>>=1;
    }
    return ans;
}
void doit(int *f,int n,int o)
{
    int len=1<<n;
    for(int i=0;i<n;++i)
        for(int j=0;j<len;++j)
            if(~j&(1<<i))
                f[j]=(f[j]+f[j^(1<<i)]*o)%mod;
}
int main()
{
    scanf("%d",&n);
    for(int i=0;i<n;++i)
    {
        int x;
        scanf("%d",&x);
        dp[x]++;
    }
    doit(dp,20,1);
    for(int i=0;i<(1<<20);++i) dp[i]=power(2,dp[i]);
    doit(dp,20,-1);
    printf("%d\n",(dp[0]+mod)%mod);
    return 0;
}

3.hihocoder 1496 寻找最大值

ai×aj×(ai&aj)a_i\times a_j \times (a_i \& a_j)的最大值,枚举&\&值,前面就选超集中的最大值和次大值,用高维前缀和处理出来。

#include<bits/stdc++.h>
using namespace std;
int n;
int f[1<<20];
int g[1<<20];

int main()
{
    int T;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d",&n);
        int len=1<<20;
        for(int i=0;i<len;++i) f[i]=g[i]=0;
        for(int i=1;i<=n;++i)
        {
            int x;
            scanf("%d",&x);
            if(f[x]) g[x]=x;
            f[x]=x;
        }
        for(int i=0;i<20;++i)
            for(int j=0;j<len;++j)
                if((~j)&(1<<i))
                {
                    int k=j^(1<<i);
                    if(f[k]>f[j])
                    {
                        g[j]=max(f[j],g[k]);
                        f[j]=f[k];
                    }
                    else g[j]=max(g[j],f[k]);
                }
        long long ans=0;
        for(int i=0;i<len;++i)
            ans=max(ans,1LL*f[i]*g[i]*i);
        printf("%lld\n",ans);
    }
    return 0;
}

4.2016 Multi-University Training Contest 4 Bonds

给一个不超过20个点的无向连通图,无重边无自环,求每条边被多少个极小割边集包括。

显然能观察到极小割只会将图分成两个连通块,那么我们先用点的连通性作为状态bfs一下,得到所有可能的分法。之后如果暴力统计的话复杂度就是O(n22n)O(n^{2}2^{n}),我们计算反面,每条边极小割边集出出现的次数等于极小割总数减去这条边连接的两点在同一连通块的情况,然后就可以高维前缀和了,O(n2n)O(n2^{n})

#include<bits/stdc++.h>
using namespace std;
#define lowbit(x) (x&(-x))
int bs[1<<21];
int q[1<<21];
int ans[1<<21];
bool vis[1<<21];
int n,m;
int a[400],b[400];
void bfs()
{
    int l=0,r=0;
    for(int i=0;i<n;++i)
        q[r++]=1<<i,vis[1<<i]=true;
    while(l!=r)
    {
        int x=q[l++];
        int y=bs[x]&(~x);
        while(y)
        {
            if(!vis[x|lowbit(y)])
            {
                vis[x|lowbit(y)]=true;
                q[r++]=x|lowbit(y);
            }
            y-=lowbit(y);
        }
    }
}
int main()
{
    int T;
    scanf("%d",&T);
    for(int cas=1;cas<=T;++cas)
    {
        scanf("%d%d",&n,&m);
        int all=1<<n;
        for(int i=0;i<all;++i) bs[i]=0,vis[i]=0,ans[i]=0;
        for(int i=1;i<=m;++i)
        {
            scanf("%d%d",a+i,b+i);
            bs[1<<a[i]]|=1<<b[i];
            bs[1<<b[i]]|=1<<a[i];
        }
        for(int i=0;i<all;++i)
            bs[i]=bs[i^lowbit(i)]|bs[lowbit(i)];
        bfs();
        int tot=0;
        for(int i=0;i<all;++i)
        if(i<((all-1)^i) && vis[i] && vis[(all-1)^i])
            ans[i]++,ans[(all-1)^i]++,tot++;
        for(int i=0;i<n;++i)
            for(int j=0;j<all;++j)
                if(~j&(1<<i))
                    ans[j]+=ans[j^(1<<i)];
        printf("Case #%d:",cas);
        for(int i=1;i<=m;++i)
            printf(" %d",tot-ans[(1<<a[i])|(1<<b[i])]);
        puts("");
    }
    return 0;
}

5.2017-2018 Petrozavodsk Winter Training Camp, Saratov SU Contest F.GCD

先从nn个数中随机一个数,然后就有大于二分之一的概率选到了最优解中的一个数,那么枚举这个数的所有约数,用至少是nkn-k个数的约数的数更新答案,重复多做几次,降低随不到的概率。关键是怎么不暴力的做后面说的这个事情。首先,一个101810^{18}范围内的数的约数个数不会很多,大概几倍的10610^{6}就够,具体范围我记得有一个表格。然后将随到数分解质因数,每个不同的质因数看成一个维度,然后求个高维后缀和,就可以了。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define N 100010
int prime[1000010],tot;
bool vis[1000010];
ll a[N],ans;
int n,k;
void getprime()
{
    for(int i=2;i<=1000000;++i)
    {
        if(!vis[i]) prime[++tot]=i;
        for(int j=1;j<=tot && prime[j]<=1000000/i;++j)
        {
            vis[i*prime[j]]=true;
            if(i%prime[j]==0) break;
        }
    }
}
ll p[50];
int num[50];
int cnt;
ll dv[5000010];
int f[5000010];
int len;
void dfs(int x,ll now)
{
    if(x>cnt)
    {
        dv[++len]=now;
        return;
    }
    dfs(x+1,now);
    for(int i=1;i<=num[x];++i)
    {
        now*=p[x];
        dfs(x+1,now);
    }
}

void divide(ll x)
{
    for(int i=1;i<=tot;++i)
    if(x%prime[i]==0)
    {
        ++cnt;
        p[cnt]=prime[i];
        num[cnt]=0;
        while(x%prime[i]==0)
            x/=prime[i],++num[cnt];
    }
    for(int i=1;i<=n;++i)
    {
        ll g=__gcd(x,a[i]);
        if(g>1 && g<x)
        {
            if(g==x/g) ++cnt,p[cnt]=g,num[cnt]=2;
            else
            {
                p[++cnt]=g;num[cnt]=1;
                p[++cnt]=x/g;num[cnt]=1;
                if(p[cnt-1]>p[cnt]) swap(p[cnt-1],p[cnt]);
            }
            x=1;
            break;
        }
    }
    if(x>1) p[++cnt]=x,num[cnt]=1;
}
void work(ll now)
{
    cnt=len=0;
    divide(now);
    dfs(1,1);
    sort(dv+1,dv+len+1);
    for(int i=1;i<=len;++i) f[i]=0;
    for(int i=1;i<=n;++i)
    {
        ll g=__gcd(now,a[i]);
        ++f[lower_bound(dv+1,dv+len+1,g)-dv];
    }
    for(int i=1;i<=cnt;++i)
    {
        ll x=p[i];
        for(int j=len,k=len;j>=1;--j)
        if(dv[j]%x==0)
        {
            ll y=dv[j]/x;
            while(dv[k]>y) --k;
            f[k]+=f[j];
        }
    }
    for(int i=1;i<=len;++i)
        if(f[i]>=n-k)
            ans=max(ans,dv[i]);
}
int main()
{
    mt19937 rnd(time(0));
    getprime();
    scanf("%d%d",&n,&k);
    for(int i=1;i<=n;++i) scanf("%lld",a+i);
    ans=1;
    for(int i=1;i<=20;++i)
        work(a[rnd()%n+1]);
    printf("%lld\n",ans);
    return 0;
}

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!

一类区间分治技巧 上一篇
线段树优化凸壳 下一篇