loj #6059. 「2017 山东一轮集训 Day1」Sum — 倍增+NTT
Contents
#6059. 「2017 山东一轮集训 Day1」Sum
内存限制:256 MiB 时间限制:1500 ms
题目描述
求有多少 n 位十进制数是 p pp 的倍数且每位之和小于等于 mi(mi=0,1,2,…,m−1,m) ,允许前导 0,答案对 998244353 取模。
输入格式
一行三个整数 n,p,m。
输出格式
输出一行 m+1个正整数,分别表示 mi=0,1,2,…,m−1,m 时的答案。
样例
样例输入
2 3 3
样例输出
1 1 1 5
数据范围与提示
首先裸dp比较好推,f[i][j][k]表示第i位,模p为j,数字和为k的方案数
由于n很大,我们可以考虑二进制拆分
倍增求出f[2^i][][],然后将n的二进制位是1的合并起来就好了
然后就是考虑如何合并两个数组
可以先枚举k,k=k1+k2,然后p^2的枚举每一个余数就好
我们不难发现这就是一个卷积
用NTT加速即可
这里复杂度是(p*mlogm+p^2 m)的,因为我们只需要一次正变换,然后乘完之后再变换回来即可
所以总复杂度(p*mlogm+p^2 m)logn
#include<map>
#include<cmath>
#include<queue>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define inf 1000000007
#define mod 998244353
#define ll long long
#define N 4010
int r[N],L=-1,G=3,nn;
ll inv;
ll ksm(ll a,int b)
{
ll sum=1;
while(b)
{
if(b&1) sum=sum*a%mod;
a=a*a%mod;b>>=1;
}
return sum;
}
ll t,f[55][N],g[55][N],h[55][N];
void NTT(ll *x,int f)
{
int i,j,k;
ll wn,w,X,Y;
for(i=0;i<nn;i++) if(i<r[i]) swap(x[i],x[r[i]]);
for(i=1;i<nn;i<<=1)
{
wn=ksm(G,(mod-1)/(i<<1));
for(j=0;j<nn;j+=(i<<1))
{
for(k=0,w=1;k<i;k++,w=w*wn%mod)
{
X=x[j+k];Y=w*x[j+k+i]%mod;
x[j+k]=(X+Y)%mod;x[j+k+i]=(X-Y+mod)%mod;
}
}
}
if(f==-1)
{
reverse(x+1,x+nn);
for(i=0;i<nn;i++) x[i]=x[i]*inv%mod;
}
}
int n,m,p;
void sol1()
{
register int i,j,k;
for(i=0;i<p;i++) NTT(f[i],1),NTT(g[i],1);
memset(h,0,sizeof(h));
for(i=0;i<p;i++)
for(j=0;j<p;j++)
for(k=0;k<nn;k++) (h[(t*i+j)%p][k]+=f[i][k]*g[j][k])%=mod;
for(i=0;i<p;i++)
{
NTT(h[i],-1),NTT(g[i],-1);
for(j=0;j<=m;j++) f[i][j]=h[i][j];
for(j=m+1;j<nn;j++) f[i][j]=0;
}
}
void sol2()
{
register int i,j,k;
for(i=0;i<p;i++) NTT(g[i],1);
memset(h,0,sizeof(h));
for(i=0;i<p;i++)
for(j=0;j<p;j++)
for(k=0;k<nn;k++) (h[(t*i+j)%p][k]+=g[i][k]*g[j][k])%=mod;
for(i=0;i<p;i++)
{
NTT(h[i],-1);
for(j=0;j<=m;j++) g[i][j]=h[i][j];
for(j=m+1;j<nn;j++) g[i][j]=0;
}
}
int main()
{
scanf("%d%d%d",&n,&p,&m);
for(nn=1;nn<=m*2;nn<<=1) L++; inv=ksm(nn,mod-2);
for(int i=0;i<nn;i++) r[i]=(r[i>>1]>>1)|((i&1)<<L);
t=10;f[0][0]=1;
for(int i=0;i<=9&&i<=m;i++) g[i%p][i]++;
while(n)
{
if(n&1) sol1();
sol2();n>>=1;t=t*t%p;
}
for(int i=1;i<=m;i++) (f[0][i]+=f[0][i-1])%=mod;
for(int i=0;i<=m;i++) printf("%lld%c",f[0][i]," \n"[i==m]);
return 0;
}

发表评论