题目描述
求(n^1+n^2+n^3+......+n^m)%p的值。
样例输入
2 2 5
样例输出
1
数据范围
n,p<=10^8,m<=10^17
题解
先说一下,这道题通法是矩阵乘法(话说……noip考这个?)。linux机子上测的,只有快速幂30分。这里没打……50分的打法用到了分治算法,原式可化为(n^1+n^2+n^3+......+n^(m div 2))+n^(m div 2)*(n^1+n^2+n^3+......+n^(m div 2))(第二个括号要根据m的奇偶性特判:m为奇数时,其上界应为(m+1) div 2)。对于n^k可以快速幂过。剩下两部分可以继续分治下去直到括号内只剩n^1。这样看似复杂度很低,但因为每次分治都需要走到底,所以只是一种比普通暴力优越一点的算法。
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<cmath>
#include<algorithm>
#define MAXN 100005
#define ll long long
using namespace std;

// Problem inputs (base n, number of terms m, modulus p) and the final answer.
ll n,p,m,ans;

// Binary exponentiation: returns x^y mod p, using the global modulus p.
ll ksm(ll x,ll y)
{
    ll acc=1;
    for(;y>0;y>>=1)
    {
        if(y&1) acc=(acc*x)%p;
        x=(x*x)%p;
    }
    return acc;
}

// Returns (n^1 + n^2 + ... + n^s) mod p by divide and conquer:
//   S(s) = S(s/2) + n^(s/2) * S(s - s/2)
// For even s both halves are the same sum, so it is evaluated once;
// for odd s the upper half has one extra term and is computed by a
// second recursive call. No results are cached.
ll work(ll s)
{
    if(s==1) return n%p;

    const ll half=s/2;
    const ll low=work(half)%p;
    const ll shift=ksm(n,half)%p;
    const ll high=(s%2==1)?work((s+1)/2)%p:low;

    return (low+(shift*high)%p)%p;
}

// Reads "n m p" from calc.in and writes the sum modulo p to calc.out.
int main()
{
    freopen("calc.in","r",stdin);
    freopen("calc.out","w",stdout);
    scanf("%lld%lld%lld",&n,&m,&p);
    n%=p;
    ans=work(m);
    printf("%lld\n",ans);
    return 0;
}
但是,这种做法有一个很大的优化空间,即以k为关键字,记录(n^1+n^2+n^3+......+n^k)%p的值。这样能节省不少时间,我的做法是对k值进行hash。加上这个优化数据全部秒过……
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<cmath>
#include<algorithm>
#define MAXN 100005
#define ll long long
#define SU 49747
#define ad 7
using namespace std;

// Number of slots in the memo table; the largest home slot is (SU-1)*ad.
const int TABLE_SIZE=500002;

// Problem inputs (base n, number of terms m, modulus p) and the final answer.
ll n,p,m,ans;

// Open-addressing memo table for work():
//   sk[i] is the argument stored in slot i (-1 marks an empty slot),
//   sh[i] is the cached value of work(sk[i]).
// The key is stored next to the value so that two different arguments that
// hash to the same home slot can never be mistaken for one another.
ll sh[TABLE_SIZE];
ll sk[TABLE_SIZE];

// Binary exponentiation: returns x^y mod p, using the global modulus p.
ll ksm(ll x,ll y)
{
    ll da=1;
    while(y>0)
    {
        if(y&1) da=(da*x)%p;
        x=(x*x)%p;
        y=y>>1;
    }
    return da;
}

// Returns the slot that holds key s, or the empty slot where s should be
// inserted. Collisions are resolved by linear probing.
int slot(ll s)
{
    int loc=(int)((s%SU)*ad);
    while(sk[loc]!=-1&&sk[loc]!=s)
        loc=(loc+1)%TABLE_SIZE;
    return loc;
}

// Returns (n^1 + n^2 + ... + n^s) mod p by divide and conquer with
// memoisation:
//   S(s) = S(s/2) + n^(s/2) * S(s - s/2)
// An empty sum (s <= 0) is 0.
ll work(ll s)
{
    if(s<=0) return 0;
    if(s==1) return n%p;

    int loc=slot(s);
    if(sk[loc]==s) return sh[loc];

    ll half=s/2;
    ll low=work(half);
    // For even s both halves are the same sum; for odd s the upper half
    // has one extra term.
    ll high=(s%2==1)?work(s-half):low;
    ll sum=(low+(ksm(n,half)*high)%p)%p;

    // The recursive calls may have filled the slot found above, so probe
    // again before inserting.
    loc=slot(s);
    sk[loc]=s;
    sh[loc]=sum;
    return sum;
}

// Reads "n m p" from calc.in and writes the sum modulo p to calc.out.
int main()
{
    freopen("calc.in","r",stdin);
    freopen("calc.out","w",stdout);
    // Malformed input or a non-positive modulus cannot be answered.
    if(scanf("%lld%lld%lld",&n,&m,&p)!=3||p<=0) return 0;
    n=n%p;
    memset(sk,-1,sizeof(sk));   // all-ones bytes make every key -1 (empty)
    ans=work(m);
    printf("%lld\n",ans);
    return 0;
}
以下是正解:矩阵乘法。
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<cmath>
#include<algorithm>
#define MAXN 100005
#define ll long long
#define SU 49747
#define ad 7
using namespace std;

// Problem inputs (base n, number of terms m, modulus p) and the final answer.
ll n,p,m,answer;

// All vectors and matrices are 1-indexed and 2x2; index 0 is unused.
//   a : start row vector (S_1, n) = (n, n), where S_k = n^1 + ... + n^k
//   b : transition matrix [[n, 0], [1, 1]], so (S_k, n) * b = (S_{k+1}, n)
//   c : accumulator that ends up holding b^(m-1)
ll a[5],b[3][3],c[3][3];

// Stores (A * B) mod p into ans. The product is built in a temporary first,
// so ans may alias A and/or B.
void mul(ll A[3][3],ll B[3][3],ll ans[3][3])
{
    ll t[3][3];
    int i,j,k;
    for(i=1;i<=2;i++)
        for(j=1;j<=2;j++)
        {
            t[i][j]=0;
            for(k=1;k<=2;k++)
                t[i][j]=(t[i][j]+A[i][k]*B[k][j]%p)%p;
        }
    for(i=1;i<=2;i++)
        for(j=1;j<=2;j++)
            ans[i][j]=t[i][j];
}

// Matrix binary exponentiation: leaves b^x (mod p) in c. b is overwritten
// with its repeated squares along the way.
void ksm(ll x)
{
    int i,j;
    // Start from the identity matrix.
    for(i=1;i<=2;i++)
        for(j=1;j<=2;j++)
            c[i][j]=(i==j)?1:0;
    while(x>0)
    {
        if(x&1) mul(b,c,c);
        mul(b,b,b);
        x=x>>1;
    }
}

// Reads "n m p" from calc.in and writes (n^1 + ... + n^m) mod p to calc.out.
int main()
{
    freopen("calc.in","r",stdin);
    freopen("calc.out","w",stdout);
    // %lld is the standard conversion for long long; %I64d is only
    // understood by the Microsoft C runtime.
    if(scanf("%lld%lld%lld",&n,&m,&p)!=3||p<=0) return 0;
    n=n%p;
    if(m<=0) {printf("%lld\n",0LL); return 0;}   // empty sum
    if(m==1) {printf("%lld\n",n); return 0;}
    a[1]=a[2]=n;
    b[1][1]=n;
    b[1][2]=0;
    b[2][1]=1;
    b[2][2]=1;
    ksm(m-1);
    // First component of a * b^(m-1) is S_m.
    for(int i=1;i<=2;i++)
        answer=(answer+(a[i]*c[i][1])%p)%p;
    printf("%lld\n",answer);
    return 0;
}