bzoj 3509: [CodeChef] COUNTARI — 分块+FFT
Contents
3509: [CodeChef] COUNTARI
Time Limit: 40 Sec Memory Limit: 128 MB
Description
给定一个长度为N的数组A[],求有多少对i, j, k(1<=i<j<k<=N)满足A[k]-A[j]=A[j]-A[i]。
Input
第一行一个整数N(N<=10^5)。
接下来一行N个数A[i](A[i]<=30000)。
Output
一行一个整数。
Sample Input
10
3 5 3 6 3 4 10 4 5 2
3 5 3 6 3 4 10 4 5 2
Sample Output
9
HINT
Source
妙啊,正常暴力思想可以枚举i,k,这样就可以n^2解决
当然我们可以发现这个很像FFT,但是似乎复杂度O(n^2logn)?
那么如何优化
考虑分块,分类讨论
- 如果i,j,k在一个块内,就可以直接暴力,复杂度O(块大小*n)
- 如果i,j或j,k,仍可以暴力,在块内枚举,预处理前缀信息,复杂度还是O(块大小*n)
- 否则就可以FFT,枚举每一块,左面与右面卷积,在中间枚举j,加起来就好了,复杂度O(n^2logn/块大小)
因为FFT有log,所以块大小应该设大一些,,设为2000比较合适?
(不过我常数好大跑的好慢啊qwq
#include<map> #include<cmath> #include<queue> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> using namespace std; #define inf 1000000007 #define ll long long #define N 130010 #define PI acos(-1) inline int rd() { int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } struct cp{ double r,i; cp(double _r=0,double _i=0):r(_r),i(_i){} cp operator + (cp x){return cp(r+x.r,i+x.i);} cp operator - (cp x){return cp(r-x.r,i-x.i);} cp operator * (cp x){return cp(r*x.r-i*x.i,r*x.i+i*x.r);} void clr(){r=i=0.0;} }; int a[N]; ll ans; int r[N],l=-1; cp A[N],B[N]; void FFT(cp *x,int n,int f) { int i,j,k; for(i=0;i<n;i++) if(i<r[i]) swap(x[i],x[r[i]]); for(i=1;i<n;i<<=1) { cp wn(cos(PI/i),f*sin(PI/i)); for(j=0;j<n;j+=(i<<1)) { cp X,Y,w(1,0); for(k=0;k<i;k++,w=w*wn) { X=x[k+j];Y=w*x[k+j+i]; x[k+j]=X+Y;x[k+j+i]=X-Y; } } } if(f==-1) for(i=0;i<n;i++) x[i].r/=n; } int ji[N]; int n,nn,tot,bk=2000,L[N],R[N]; void sol(int p) { for(int i=0;i<nn;i++) A[i].clr(),B[i].clr(); for(int i=1;i<L[p];i++) A[a[i]].r+=1; for(int i=R[p]+1;i<=n;i++) B[a[i]].r+=1; FFT(A,nn,1);FFT(B,nn,1); for(int i=0;i<nn;i++) A[i]=A[i]*B[i]; FFT(A,nn,-1); for(int i=L[p];i<=R[p];i++) ans+=(int)(A[a[i]<<1].r+0.1); } int main() { n=rd(); for(int i=1;i<=n;i++) a[i]=rd(); for(nn=1;nn<=60000;nn<<=1)l++; for(int i=0;i<nn;i++) r[i]=(r[i>>1]>>1)|((i&1)<<l); while(tot*bk<n) L[++tot]=R[tot-1]+1,R[tot]=tot*bk; R[tot]=min(R[tot],n); // i,j,k in block for(int p=1;p<=tot;p++) { for(int i=L[p],k;i<R[p];i++) { for(k=i+2;k<=R[p];k++) { ji[a[k-1]]++; if(!((a[i]+a[k])&1))ans+=ji[(a[i]+a[k])>>1]; } for(k=i+1;k<R[p];k++) ji[a[k]]--; } } //i,j in block memset(ji,0,sizeof(ji)); for(int i=1;i<=n;i++) ji[a[i]]++; for(int p=1;p<tot;p++) { for(int i=L[p];i<=R[p];i++) ji[a[i]]--; for(int i=L[p],j;i<R[p];i++) for(j=i+1;j<=R[p];j++) if(a[j]*2-a[i]>=0)ans+=ji[a[j]*2-a[i]]; } //j,k in block memset(ji,0,sizeof(ji)); for(int i=1;i<=n;i++) ji[a[i]]++; for(int p=tot;p;p--) { for(int i=L[p];i<=R[p];i++) ji[a[i]]--; for(int j=L[p],k;j<R[p];j++) for(k=j+1;k<=R[p];k++) if(a[j]*2-a[k]>=0) ans+=ji[a[j]*2-a[k]]; } //j in block FFT for(int i=2;i<tot;i++) sol(i); printf("%lld\n",ans); return 0; }
发表评论