矩阵乘法优化

bbqub

概念

矩阵相乘最重要的方法是一般矩阵乘积。它只有在第一个矩阵的列数(column)和第二个矩阵的行数(row)相同时才有意义[1] 。一般单指矩阵乘积时,指的便是一般矩阵乘积。一个m×n的矩阵就是m×n个数排成m行n列的一个数阵。由于它把许多数据紧凑的集中到了一起,所以有时候可以简便地表示一些复杂的模型。

bg2015090104.png

这个结果是怎么算出来的?
教科书告诉你,计算规则是,第一个矩阵第一行的每个数字(2和1),各自乘以第二个矩阵第一列对应位置的数字(1和1),然后将乘积相加(2 x 1 + 1 x 1),得到结果矩阵左上角的那个值3。

bg2015090105.gif

也就是说,结果矩阵第m行与第n列交叉位置的那个值,等于第一个矩阵第m行与第二个矩阵第n列,对应位置的每个值的乘积之和

bg2015090106.png
bg2015090107.png

有三组未知数 x、y 和 t,其中 x 和 y 的关系如下
bg2015090108.png

理解

矩阵快速幂:
F(0) = 0
F(1) = 1
F(n) = F(n - 1) + F(n - 2) (n >= 2)
(1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, ...)
给出n,求F(n),由于结果很大,输出F(n) % 1000000007的结果即可。

分析 :

斐波那契数列的递推式为 f(n) = f(n-1)+f(n-2) ,直接循环求出 f(n) 的时间复杂度是 O(n) ,对于题目中的数据范围显然无法承受。很明显我们需要对数级别的算法。由于 f(n) = 1*f(n-1) + 1*f(n-2) 这样的形式很类似于矩阵的乘法,所以我们可以先把这个问题复杂化一下,将递推求解 f(n)f(n-1) 的过程看作是某两个矩阵相乘的结果,式子如下:

图片描述

所以我们只要不断地乘以上面式子中的第二个矩阵(也就是第二个矩阵的幂)就能够不断递推得到 f(n) 。但是这样于解题没有丝毫益处,反而使得常数变得更大(矩阵乘法的复杂度为立方级别)。所以我们就要利用矩阵乘法的一条重要性质:结合律。即矩阵 (A*B)*C = A*(B*C) ,证明过程可参见 2008 年国家集训队俞华程的论文。

图片描述

有了结合律我们就可以用快速幂计算矩阵的幂,问题的复杂度顺利降到了 O(logn)

代码

#include<iostream>    
#include<memory.h>    
#include<cstdlib>    
#include<cstdio>    
#include<cmath>    
#include<cstring>    
#include<string>    
#include<cstdlib>    
#include<iomanip>    
#include<vector>    
#include<list>    
#include<map>    
#include<algorithm>    
typedef long long LL;    
const LL maxn=1000+10;  
const LL mod=1000000007;  
const int N=2;  
using namespace std;   
struct Matrix  
{  
    LL m[N][N];  
};  
Matrix A=  
{  
    1,1,  
    1,0  
};  
Matrix I=  
{  
    1,0,  
    0,1  
};  
Matrix multi(Matrix a,Matrix b)  
{  
    Matrix c;  
    for(int i=0;i<N;i++)  
    {  
        for(int j=0;j<N;j++)  
        {  
            c.m[i][j]=0;  
            for(int k=0;k<N;k++)  
                c.m[i][j]+=a.m[i][k]*b.m[k][j]%mod;  
                      
            c.m[i][j]%=mod;  
        }  
    }  
    return c;  
}  
Matrix power(Matrix A,int k)  
{  
    Matrix ans=I,p=A;  
    while(k)  
    {  
        if(k&1)  
        {  
            ans=multi(ans,p);  
            k--;  
        }  
        k>>=1;  
        p=multi(p,p);  
    }  
    return ans;  
}  
int main()  
{  
    int n;  
    while(~scanf("%d",&n))  
    {  
        Matrix ans =power(A,n-1);  
        printf("%lld\n",ans,m[0][0]);  
    }  
    return 0;  
}  
阅读 6k
6 声望
0 粉丝
0 条评论
6 声望
0 粉丝
文章目录
宣传栏