loj #6059. 「2017 山东一轮集训 Day1」Sum — 倍增+NTT
Contents
#6059. 「2017 山东一轮集训 Day1」Sum
内存限制:256 MiB 时间限制:1500 ms
题目描述
求有多少 n 位十进制数是 p pp 的倍数且每位之和小于等于 mi(mi=0,1,2,…,m−1,m) ,允许前导 0,答案对 998244353 取模。
输入格式
一行三个整数 n,p,m。
输出格式
输出一行 m+1个正整数,分别表示 mi=0,1,2,…,m−1,m 时的答案。
样例
样例输入
2 3 3
样例输出
1 1 1 5
数据范围与提示
首先裸dp比较好推,f[i][j][k]表示第i位,模p为j,数字和为k的方案数
由于n很大,我们可以考虑二进制拆分
倍增求出f[2^i][][],然后将n的二进制位是1的合并起来就好了
然后就是考虑如何合并两个数组
可以先枚举k,k=k1+k2,然后p^2的枚举每一个余数就好
我们不难发现这就是一个卷积
用NTT加速即可
这里复杂度是(p*mlogm+p^2 m)的,因为我们只需要一次正变换,然后乘完之后再变换回来即可
所以总复杂度(p*mlogm+p^2 m)logn
#include<map> #include<cmath> #include<queue> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> using namespace std; #define inf 1000000007 #define mod 998244353 #define ll long long #define N 4010 int r[N],L=-1,G=3,nn; ll inv; ll ksm(ll a,int b) { ll sum=1; while(b) { if(b&1) sum=sum*a%mod; a=a*a%mod;b>>=1; } return sum; } ll t,f[55][N],g[55][N],h[55][N]; void NTT(ll *x,int f) { int i,j,k; ll wn,w,X,Y; for(i=0;i<nn;i++) if(i<r[i]) swap(x[i],x[r[i]]); for(i=1;i<nn;i<<=1) { wn=ksm(G,(mod-1)/(i<<1)); for(j=0;j<nn;j+=(i<<1)) { for(k=0,w=1;k<i;k++,w=w*wn%mod) { X=x[j+k];Y=w*x[j+k+i]%mod; x[j+k]=(X+Y)%mod;x[j+k+i]=(X-Y+mod)%mod; } } } if(f==-1) { reverse(x+1,x+nn); for(i=0;i<nn;i++) x[i]=x[i]*inv%mod; } } int n,m,p; void sol1() { register int i,j,k; for(i=0;i<p;i++) NTT(f[i],1),NTT(g[i],1); memset(h,0,sizeof(h)); for(i=0;i<p;i++) for(j=0;j<p;j++) for(k=0;k<nn;k++) (h[(t*i+j)%p][k]+=f[i][k]*g[j][k])%=mod; for(i=0;i<p;i++) { NTT(h[i],-1),NTT(g[i],-1); for(j=0;j<=m;j++) f[i][j]=h[i][j]; for(j=m+1;j<nn;j++) f[i][j]=0; } } void sol2() { register int i,j,k; for(i=0;i<p;i++) NTT(g[i],1); memset(h,0,sizeof(h)); for(i=0;i<p;i++) for(j=0;j<p;j++) for(k=0;k<nn;k++) (h[(t*i+j)%p][k]+=g[i][k]*g[j][k])%=mod; for(i=0;i<p;i++) { NTT(h[i],-1); for(j=0;j<=m;j++) g[i][j]=h[i][j]; for(j=m+1;j<nn;j++) g[i][j]=0; } } int main() { scanf("%d%d%d",&n,&p,&m); for(nn=1;nn<=m*2;nn<<=1) L++; inv=ksm(nn,mod-2); for(int i=0;i<nn;i++) r[i]=(r[i>>1]>>1)|((i&1)<<L); t=10;f[0][0]=1; for(int i=0;i<=9&&i<=m;i++) g[i%p][i]++; while(n) { if(n&1) sol1(); sol2();n>>=1;t=t*t%p; } for(int i=1;i<=m;i++) (f[0][i]+=f[0][i-1])%=mod; for(int i=0;i<=m;i++) printf("%lld%c",f[0][i]," \n"[i==m]); return 0; }
发表评论