bzoj 3509: [CodeChef] COUNTARI — 分块+FFT
Contents
3509: [CodeChef] COUNTARI
Time Limit: 40 Sec Memory Limit: 128 MB
Description
给定一个长度为N的数组A[],求有多少对i, j, k(1<=i<j<k<=N)满足A[k]-A[j]=A[j]-A[i]。
Input
第一行一个整数N(N<=10^5)。
接下来一行N个数A[i](A[i]<=30000)。
Output
一行一个整数。
Sample Input
10
3 5 3 6 3 4 10 4 5 2
3 5 3 6 3 4 10 4 5 2
Sample Output
9
HINT
Source
妙啊,正常暴力思想可以枚举i,k,这样就可以n^2解决
当然我们可以发现这个很像FFT,但是似乎复杂度O(n^2logn)?
那么如何优化
考虑分块,分类讨论
- 如果i,j,k在一个块内,就可以直接暴力,复杂度O(块大小*n)
- 如果i,j或j,k,仍可以暴力,在块内枚举,预处理前缀信息,复杂度还是O(块大小*n)
- 否则就可以FFT,枚举每一块,左面与右面卷积,在中间枚举j,加起来就好了,复杂度O(n^2logn/块大小)
因为FFT有log,所以块大小应该设大一些,,设为2000比较合适?
(不过我常数好大跑的好慢啊qwq
#include<map>
#include<cmath>
#include<queue>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define inf 1000000007
#define ll long long
#define N 130010
#define PI acos(-1)
inline int rd()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
struct cp{
double r,i;
cp(double _r=0,double _i=0):r(_r),i(_i){}
cp operator + (cp x){return cp(r+x.r,i+x.i);}
cp operator - (cp x){return cp(r-x.r,i-x.i);}
cp operator * (cp x){return cp(r*x.r-i*x.i,r*x.i+i*x.r);}
void clr(){r=i=0.0;}
};
int a[N];
ll ans;
int r[N],l=-1;
cp A[N],B[N];
void FFT(cp *x,int n,int f)
{
int i,j,k;
for(i=0;i<n;i++) if(i<r[i]) swap(x[i],x[r[i]]);
for(i=1;i<n;i<<=1)
{
cp wn(cos(PI/i),f*sin(PI/i));
for(j=0;j<n;j+=(i<<1))
{
cp X,Y,w(1,0);
for(k=0;k<i;k++,w=w*wn)
{
X=x[k+j];Y=w*x[k+j+i];
x[k+j]=X+Y;x[k+j+i]=X-Y;
}
}
}
if(f==-1) for(i=0;i<n;i++) x[i].r/=n;
}
int ji[N];
int n,nn,tot,bk=2000,L[N],R[N];
void sol(int p)
{
for(int i=0;i<nn;i++) A[i].clr(),B[i].clr();
for(int i=1;i<L[p];i++) A[a[i]].r+=1;
for(int i=R[p]+1;i<=n;i++) B[a[i]].r+=1;
FFT(A,nn,1);FFT(B,nn,1);
for(int i=0;i<nn;i++) A[i]=A[i]*B[i];
FFT(A,nn,-1);
for(int i=L[p];i<=R[p];i++) ans+=(int)(A[a[i]<<1].r+0.1);
}
int main()
{
n=rd();
for(int i=1;i<=n;i++) a[i]=rd();
for(nn=1;nn<=60000;nn<<=1)l++;
for(int i=0;i<nn;i++) r[i]=(r[i>>1]>>1)|((i&1)<<l);
while(tot*bk<n) L[++tot]=R[tot-1]+1,R[tot]=tot*bk;
R[tot]=min(R[tot],n);
// i,j,k in block
for(int p=1;p<=tot;p++)
{
for(int i=L[p],k;i<R[p];i++)
{
for(k=i+2;k<=R[p];k++)
{
ji[a[k-1]]++;
if(!((a[i]+a[k])&1))ans+=ji[(a[i]+a[k])>>1];
}
for(k=i+1;k<R[p];k++) ji[a[k]]--;
}
}
//i,j in block
memset(ji,0,sizeof(ji));
for(int i=1;i<=n;i++) ji[a[i]]++;
for(int p=1;p<tot;p++)
{
for(int i=L[p];i<=R[p];i++) ji[a[i]]--;
for(int i=L[p],j;i<R[p];i++)
for(j=i+1;j<=R[p];j++)
if(a[j]*2-a[i]>=0)ans+=ji[a[j]*2-a[i]];
}
//j,k in block
memset(ji,0,sizeof(ji));
for(int i=1;i<=n;i++) ji[a[i]]++;
for(int p=tot;p;p--)
{
for(int i=L[p];i<=R[p];i++) ji[a[i]]--;
for(int j=L[p],k;j<R[p];j++)
for(k=j+1;k<=R[p];k++)
if(a[j]*2-a[k]>=0) ans+=ji[a[j]*2-a[k]];
}
//j in block FFT
for(int i=2;i<tot;i++) sol(i);
printf("%lld\n",ans);
return 0;
}
发表评论