现在的位置: 首页 > 综合 > 正文

Matrix Power Series—-矩阵乘法(二分)

2013年09月08日 ⁄ 综合 ⁄ 共 2883字 ⁄ 字号 评论关闭

题目:http://poj.org/problem?id=3233

A+A^2+A^3……+A^k

如果k是偶数的话,原式=(I+A^k/2)(A+A^2……A^k/2)。

如果k是奇数的话,原式=(1+A^(k/2+1))(A+A^2……+A^k/2)+A^(k/2+1)。

我是用数组直接实现的,比较麻烦,速度比较慢,1000+ms。

下面是我的代码:

#include <stdio.h>

int m,n,k;
int a[30][30];
int ans[30][30],s[30][30],I[30][30];

void init(int x[30][30])
{
  for(int i=0;i<n;i++)
    for(int j=0;j<n;j++)
     x[i][j]=0;
  for(int i=0;i<n;i++)
    for(int j=0;j<n;j++)
     if(i==j)
     x[i][j]=1;
}

void add(int x[30][30],int y[30][30],int z[30][30])
{
    for(int i=0;i<n;i++)
      for(int j=0;j<n;j++)
      z[i][j]=(x[i][j]+y[i][j])%m;
}

void multiply(int x[30][30],int y[30][30],int z[30][30])
{
    for(int i=0;i<n;i++)
      for(int j=0;j<n;j++)
         z[i][j]=0;
    for(int i=0;i<n;i++)
      for(int j=0;j<n;j++)
        for(int k=0;k<n;k++)
         z[i][j]=(z[i][j]+x[i][k]*y[k][j])%m;
}

void copy(int x[30][30],int y[30][30])
{
    for(int i=0;i<n;i++)
      for(int j=0;j<n;j++)
       y[i][j]=x[i][j];
}

void quickPow(int a[30][30],int z[30][30],int n)
{
    int x[30][30],y[30][30],b[30][30];
    copy(a,b);
    init(x);  init(z);
    while(n>0)
    {
        if(n%2==1)  multiply(b,x,z);
        n=n/2;
        copy(z,x);
        multiply(b,b,y);
        copy(y,b);
    }
}

void solve(int a[30][30],int k)
{
    int b[30][30],c[30][30];
    if(k==1)  { copy(a,ans); copy(a,s); return ; }
    solve(a,k/2);
    if(k%2==0)
    {
       quickPow(a,b,k/2);
       add(I,b,c);
       multiply(c,s,ans);
       copy(ans,s);
    }
    else
    {
        quickPow(a,b,k/2+1);
        add(I,b,c);
        multiply(c,s,ans);
        copy(ans,s);
        add(b,s,ans);
        copy(ans,s);
    }
}
int main()
{
    //freopen("D:\\a.txt","r",stdin);
    scanf("%d%d%d",&n,&k,&m);
    for(int i=0;i<n;i++)
     for(int j=0;j<n;j++)
      scanf("%d",&a[i][j]);
     init(s);  init(ans); init(I);
     solve(a,k);
     for(int i=0;i<n;i++)
     {
         for(int j=0;j<n;j++)
         printf("%d ",ans[i][j]);
         printf("\n");
     }
}

后来看了一些比较快的代码,200-ms,他们使用结构体实现的,都知道结构体可以作为函数的返回值,操作简单,速度相对就快了。

下面是看到的代码:

原文地址http://www.cnblogs.com/forever4444/archive/2009/05/12/1454736.html

代码:

#include <iostream>

#define MAX 33

using namespace std;

typedef struct node

{

    int matirx[MAX][MAX];

}Matrix;

Matrix a,sa,unit;

int n,k,m;

void Init()

{

    int i,j;

    for(i=0;i<n;i++)

        for(j=0;j<n;j++)

        {

            scanf("%d",&a.matirx[i][j]);

            a.matirx[i][j]%=m;   //初始化要先%m

            unit.matirx[i][j]=(i==j);  //如果i==j那么矩阵中此值就是1,否则为0,就是主对角线是1的单位矩阵

        }

}

Matrix Add(Matrix a,Matrix b)  //矩阵加法

{

    Matrix c;

    int i,j;

    for(i=0;i<n;i++)

        for(j=0;j<n;j++)

        {

            c.matirx[i][j]=a.matirx[i][j]+b.matirx[i][j];

            c.matirx[i][j]%=m;   //加的时候也要%m

        }

    return c;

}

Matrix Mul(Matrix a,Matrix b)  //矩阵乘法

{

    Matrix c;

    int i,j,k;

    for(i=0;i<n;i++)

        for(j=0;j<n;j++)

        {

            c.matirx[i][j]=0;  //初始化矩阵c

            for(k=0;k<n;k++)

                c.matirx[i][j]+=a.matirx[i][k]*b.matirx[k][j];

            c.matirx[i][j]%=m;  //计算乘法的时候也要%m

        }

    return c;

}

Matrix Cal(int exp)   //矩阵幂

{

    Matrix p,q;

    p=a;  //p是初始矩阵

    q=unit;  //q是单位矩阵

    while(exp!=1)

    {

        if(exp&1)  //要求得幂是奇数

        {

            exp--;

            q=Mul(p,q);

        } 

        else    //要求的幂是偶数

        {

            exp>>=1;  //相当于除2

            p=Mul(p,p);

        }

    }

    p=Mul(p,q);

    return p;

}

Matrix MatrixSum(int k)

{

    if(k==1)  //做到最底层就将矩阵a返回就好

        return a;

    Matrix temp,tnow;

    temp=MatrixSum(k/2);

    if(k&1)  //如果k是奇数

    {

        tnow=Cal(k/2+1);

        temp=Add(temp,Mul(temp,tnow));

        temp=Add(tnow,temp);

    }  

    else    //如果k是偶数

    {

        tnow=Cal(k/2);

        temp=Add(temp,Mul(temp,tnow));

    }

    return temp;

}

int main()

{

    int i,j;

    while(scanf("%d%d%d",&n,&k,&m)!=EOF)

    {

        Init();

        sa=MatrixSum(k);

        for(i=0;i<n;i++)

        {

            for(j=0;j<n-1;j++)

            {

                printf("%d ",sa.matirx[i][j]%m);

            }

            printf("%d\n",sa.matirx[i][n-1]%m);

        }

    }

    return 0;

抱歉!评论已关闭.