Strassen矩阵乘法——C++

【题目描述】

根据课本“Strassen矩阵乘法”的基本原理,设计并实现一个矩阵快速乘法的工具。并演示至少10000维的矩阵快速乘法对比样例。

【功能要求】
  1. 实现普通矩阵乘法算法和“Strassen矩阵乘法”算法
  2. 对相同的矩阵,分别用普通矩阵乘法算法,“Strassen矩阵乘法”算法和Matlab进行运算,比较时间差异(多次计算求平均值);
【选做功能】
  1. 突破2n的维数限制,能够对其他维数的矩阵进行运算。
  2. 方法不限,实现尽可能快的矩阵计算。
  3. 其他可扩展的功能。
【实验过程】
  1. 首先我们先设计实现普通的矩阵乘法,对于两个矩阵,普通的矩阵相乘做法是:遍历三层矩阵计算:我们设A和B是2个n*n的矩阵,它们的乘积AB同样是一个n*n矩阵。 A和B的乘积矩阵C中元素C[i][j]定义为:

比如,我们以下列的例子作为参考:对于它们的乘积,我们应该使用公式:

所以,从上述的公式中,我们知道如果使用这正常的矩阵相乘,由此得出:

所以我们的计算的时间复杂度是O(n^3)。

计算的代码为:对于数据的输入,我们使用的是将数据存储在data.txt中,每次去读取这个文件中的矩阵规模n和矩阵 arr1[][] 和 arr2[][]

#include<iostream>
#include<time.h>
#include "fstream"
void Multiply(int pInt, long long **pInt1, long long **pInt2, long long **pInt3);

void out(int pInt, long long **pInt1);

using namespace std;

int main() {
    system("chcp 65001 > nul");

    std::ios::sync_with_stdio(false);
    std::cin.tie(0);
//    c++加速流
    int M;
    fstream f;
    f.open("data.txt",ios::in);
    f >> M;

    int length = M;

    if (M % 2 != 0) //若M为奇数,则补零
    {
        length++;
    }

    long long **A = new long long *[length];
    long long **B = new long long *[length];
    long long **C = new long long *[length];

    for (int i = 0; i < length; i++) {
        A[i] = new long long[length];
        B[i] = new long long[length];
        C[i] = new long long[length];
    }

    for (int i = 0; i < M; i++) {
        for (int j = 0; j < M; j++)
            f >> A[i][j];
    }
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < M; j++) {
            C[i][j] = 0;
            f >> B[i][j];
        }
    }

    clock_t start;
    clock_t end;
    start = clock();
    Multiply(M, A, B, C);
    end = clock();
    cout <<"当数据量n为"<<M<<"时,耗费的时间:"<< (end - start) << "ms" << endl;  //输出时间(单位:ms)
//    out(M, C);

}

void out(int n, long long **arr) {
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            cout << arr[i][j] << " ";
        }
        cout << endl;
    }
}

void Multiply(int n, long long **A, long long **B, long long **C) {
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            for (int k = 0; k < n; k++) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
}

2. 观察这个算法之后,我们发现,在计算矩阵相乘的时候,时间复杂度达到了O(n^3)。 如果n过于大的话,需要计算很久才会出结果。对于10000 × 10000的数据量二维数组存储的话会爆栈。因此我们使用更加高效的算法Strassen矩阵乘法。

3. 1969年,Volker Strassen提出了第一个算法时间复杂度低于O(n^3的矩阵乘法算法,算法复杂度为,还是很接近3,因此StrassenStrassen 算法只有在对于维数比较大的矩阵,性能上才可能有优势,可以减少很多乘法计算。StrassenStrassen 算法证明了矩阵乘法存在时间复杂度低于O(n^3的算法的存在,后续学者不断研究发现新的更快的算法,截止目前时间复杂度最低的矩阵乘法算法是 Coppersmith-Winograd 方法的一种扩展方法,其算法复杂度为

4. Strassen原理详解:

假设矩阵 A 和矩阵 B 都是  N×N (N = 2^n)的方矩阵,求 C = AB,如下所示:

其中,8T(2n​) 表示 8 次矩阵乘法,而且相乘的矩阵规模降到了 2n​。

O()表示 4 次矩阵加法的时间复杂度以及合并矩阵 C 的时间复杂度。

最终可计算得到  T(n)= O()。

可以看出每次递归操作都需要 8 次矩阵相乘,而这正是瓶颈的来源。相比加法,矩阵乘法是非常慢的,于是减少矩阵相乘的次数就显得尤为重要。Strassen算法的主要目的其实也是从这个角度出发的,目的就是减少乘法次数,降低时间复杂度。

5. Strassen的实现步骤:

① 对于上述的A、B、C三个矩阵进行分解,分解花费的时间复杂度是O(1)

② 然后我们创建如下的10个 ×  的矩阵 S1 ,S2 ,S3 …… S10 ,花费的时间复杂大约是O(

③ 接下来递归计算七个矩阵P1 ,P2 ,P3 …… P7 每个P都是 n2 × n2 的矩阵。

④ 接着通过Pi 来计算C11 C12 C21 C22 ,花费的时间为O()。

这样就相对减少了一些时间复杂度。代码如下:

#include <iostream>
#include <time.h>
#include <fstream>
void out(int m, int **pInt);

using namespace std;

void subMatrix(int l, long long **m, long long **n, long long **ans) {
    for (int i = 0; i < l; i++) {
        for (int j = 0; j < l; j++) {
            ans[i][j] = m[i][j] - n[i][j];
        }
    }
}

void addMatrix(int l, long long **m, long long **n, long long **ans) //两矩阵加法
{
    for (int i = 0; i < l; i++) {
        for (int j = 0; j < l; j++) {
            ans[i][j] = m[i][j] + n[i][j];
        }
    }
}

void multiMatrix(int l, long long **m, long long **n, long long **ans) {
    for (int i = 0; i < l; i++) {
        for (int j = 0; j < l; j++) {
            ans[i][j] = 0;
            for (int k = 0; k < l; k++) {
                ans[i][j] += m[i][k] * n[k][j];
            }
        }
    }
}

void Strassen(int M, long long **A, long long **B, long long **C) {
    int len = M / 2;
    long long **A11 = new long long  *[len];
    long long **A12 = new long long  *[len];
    long long **A21 = new long long  *[len];
    long long **A22 = new long long  *[len];
    long long **B11 = new long long  *[len];
    long long **B12 = new long long  *[len];
    long long **B21 = new long long  *[len];
    long long **B22 = new long long  *[len];
    long long **C11 = new long long  *[len];
    long long **C12 = new long long  *[len];
    long long **C21 = new long long  *[len];
    long long **C22 = new long long  *[len];

    long long **P1 = new  long long *[len];
    long long **P2 = new  long long *[len];
    long long **P3 = new  long long *[len];
    long long **P4 = new  long long *[len];
    long long **P5 = new  long long *[len];
    long long **P6 = new  long long *[len];
    long long **P7 = new  long long *[len];

    long long **AR = new long long *[len];
    long long **BR = new long long *[len];

    for (int i = 0; i < len; i++) {
        A11[i] = new long long [len];
        A12[i] = new long long [len];
        A21[i] = new long long [len];
        A22[i] = new long long [len];
        B11[i] = new long long [len];
        B12[i] = new long long [len];
        B21[i] = new long long [len];
        B22[i] = new long long [len];
        C11[i] = new long long [len];
        C12[i] = new long long [len];
        C21[i] = new long long [len];
        C22[i] = new long long [len];
        P1[i] = new  long long [len];
        P2[i] = new  long long [len];
        P3[i] = new  long long [len];
        P4[i] = new  long long [len];
        P5[i] = new  long long [len];
        P6[i] = new  long long [len];
        P7[i] = new  long long [len];
        AR[i] = new  long long [len];
        BR[i] = new  long long [len];
    }

    for (int i = 0; i < len; i++) {
        for (int j = 0; j < len; j++) {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][j + len];
            A21[i][j] = A[i + len][j];
            A22[i][j] = A[i + len][j + len];

            B11[i][j] = B[i][j];
            B12[i][j] = B[i][j + len];
            B21[i][j] = B[i + len][j];
            B22[i][j] = B[i + len][j + len];
        }
    }
    addMatrix(len, A11, A22, AR);
    addMatrix(len, B11, B22, BR);
    multiMatrix(len, AR, BR, P1);

    addMatrix(len, A21, A22, AR);
    multiMatrix(len, AR, B11, P2);

    subMatrix(len, B12, B22, BR);
    multiMatrix(len, A11, BR, P3);

    subMatrix(len, B21, B11, BR);
    multiMatrix(len, A22, BR, P4);

    addMatrix(len, A11, A12, AR);
    multiMatrix(len, AR, B22, P5);

    subMatrix(len, A21, A11, AR);
    addMatrix(len, B11, B12, BR);
    multiMatrix(len, AR, BR, P6);

    subMatrix(len, A12, A22, AR);
    addMatrix(len, B21, B22, BR);
    multiMatrix(len, AR, BR, P7);

    addMatrix(len, P1, P4, AR);
    subMatrix(len, P7, P5, BR);
    addMatrix(len, AR, BR, C11);

    addMatrix(len, P3, P5, C12);

    addMatrix(len, P2, P4, C21);

    addMatrix(len, P1, P3, AR);
    subMatrix(len, P6, P2, BR);
    addMatrix(len, AR, BR, C22);

    for (int i = 0; i < len; i++) {
        for (int j = 0; j < len; j++) {
            C[i][j] = C11[i][j];
            C[i][j + len] = C12[i][j];
            C[i + len][j] = C21[i][j];
            C[i + len][j + len] = C22[i][j];
        }
    }
}

int main() {
    system("chcp 65001 > nul");
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);
//    c++加速流
    int M;

    fstream f;
    f.open("data.txt",ios::in);
    f >> M;
    int length = M;

    if (M % 2 != 0) //若M为奇数,则补零
    {
        length++;
    }

    long long **A = new long long *[length];
    long long **B = new long long *[length];
    long long **C = new long long *[length];

    for (int i = 0; i < length; i++) {
        A[i] = new long long [length];
        B[i] = new long long [length];
        C[i] = new long long [length];
    }
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < M; j++)
            f >> A[i][j];
    }
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < M; j++) {
            f >> B[i][j];
        }
    }

    if (length > M) {
        for (int i = 0; i < length; i++) {
            A[i][M] = 0;
            A[M][i] = 0;
            B[i][M] = 0;
            B[M][i] = 0;
        }
    }

    clock_t start;
    clock_t end;
    start = clock();
    Strassen(length, A, B, C);
    end = clock();
    cout <<"当数据量n为"<<M<<"时,耗费的时间:"<< (end - start) << "ms" << endl;  //输出时间(单位:ms)
// 输出
//    out(M,C);

    return 0;
}

void out(int M, int **C) {
    for (int i = 0; i < M; i++)
    {
        for (int j = 0; j < M; j++)
        {
            cout << C[i][j] << " \n"[j == M - 1];
        }
    }
}

接着,我们通过改变数据量的大小,来比较这两个算法的耗时。

对于测试数据的生成,我们使用makeData.cpp来生成并保存到文件data.txt。代码如下:

//简单的随机制造数据
#include<iostream>
#include <ctime>
#include "stdlib.h"
#include "fstream"
using namespace std;
// 左闭右闭区间
int getRand(int min, int max) {
    return (rand() % (max - min + 1)) + min;
}

int main() {
    int n;
    cin >> n;
    fstream f;
    f.open("data.txt", ios::out);
    f << n << endl;
    srand(time(0));
    for (int i = 0; i < 2 * n; i++) {
        for (int j = 0; j < n; j++) {
            f << getRand(0, 10) << " ";
        }
        f << endl;
    }
    f.close();
    return 0;
}

我们使用了上述的矩阵生成代码,随机创建了10000×10000大小的矩阵进行测试,如下图所示:

计算得到结果:

我们再使用Matlab来计算一下两个矩阵相乘的耗时:

统计得到:(其中的数据都是由3次统计求平均值的方式得来的。)

数据量

10

50

100

500

1000

2000

3000

普通

0

0.7

17.5

1644.667

20423.67

209913.5

779112

Strassen

0

2.5

11.5

1426

11692.67

90211.33

521274

matlab

0

0.185

0.328

2.926

16.38

237.273

407.674

【实验结论】

最后,比较得出结论:

1. 在矩阵规模较小的情况下,(例如 n<64),普通的矩阵相乘算法表现更优,耗时更短。

2. 当矩阵规模较大时,Strassen算法表现更优,耗时更短。因为在矩阵规模较大时,Strassen算法所需的递归次数相对较少,而且该算法每一次递归所做的运算规模较小,这些都有利于提高运算效率。

3. 在Matlab中,可以使用自带的矩阵乘法函数*来进行矩阵相乘运算,该函数会根据矩阵规模和计算机硬件等情况自动选择最优算法进行计算。因此,在实际应用中,建议使用内置的矩阵乘法函数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/582289.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

电机控制系列模块解析(11)—— 电流采样

一、电流采样分类 由下图可知&#xff0c;采样电阻的位置不同&#xff0c;电流采样分为输出电流采样、下桥电流采样、母线电流采样。 输出电流采样 定义&#xff1a;输出电流采样是指对电机定子绕组或转子绕组&#xff08;对于内转子永磁同步电机&#xff09;输出的电流进行测…

什么是区块链?智能合约有什么用?

一、什么是区块链&#xff1f; 区块链是一种去中心化的分布式账本技术&#xff0c;通过加密和共识机制确保数据的安全和透明。它将交易数据按照时间顺序记录在区块中&#xff0c;并通过链式链接保证了数据的不可篡改性。 二、什么是智能合约&#xff1f; 智能合约是运行在区…

如何修改php版本

我使用的Hostease的Windows虚拟主机产品,由于网站程序需要支持高版本的PHP,程序已经上传到主机&#xff0c;但是没有找到切换PHP以及查看PHP有哪些版本的位置&#xff0c;因此咨询了Hostease的技术支持&#xff0c;寻求帮助了解到可以实现在Plesk面板上找到此切换PHP版本的按钮…

linux tcpdump的交叉编译以及使用

一、源码下载 官网&#xff1a;点击跳转 二、编译 1、解压 tar -xf libpcap-1.10.4.tar.xz tar -xf tcpdump-4.99.4.tar.xz 2、配置及编译 //libpcap&#xff1a; ./configure --hostarm-linux --targetarm-linux CCarm-linux-gcc --with-pcaplinux --prefix$PWD/build//t…

37 线程控制

内核中没有明确的线程的概念&#xff0c;线程作为轻量级进程。所以不会提供线程的系统调用&#xff0c;只提供了轻量级进程的系统调用&#xff0c;但这个接口比较复杂&#xff0c;使用很不方便&#xff0c;我们用户&#xff0c;需要一个线程的接口。应用层对轻量级进程的接口进…

企业如何保证内部传输文件使用的工具是安全的?

企业内部文件的频繁交换成为了日常运营不可或缺的一环。然而&#xff0c;随着数据量的爆炸式增长和网络攻击手段的日益复杂&#xff0c;内网文件传输的安全隐患也日益凸显&#xff0c;成为企业信息安全的薄弱环节。本文将探讨内网文件传输的安全风险、企业常用的防护措施。 内网…

Python轻量级Web框架Flask(12)—— Flask类视图实现前后端分离

0、前言&#xff1a; 在学习类视图之前要了解前后端分离的概念&#xff0c;相对于之前的模板&#xff0c;前后端分离的模板会去除views文件&#xff0c;添加两个新python文件apis和urls&#xff0c;其中apis是用于传输数据和解析数据 的&#xff0c;urls是用于写模板路径的。 …

终于有人把无人机5G通信原理讲清楚了

在现代科技快速发展的背景下&#xff0c;无人机技术在各个领域都有了广泛应用&#xff0c;从送外卖到农业监控&#xff0c;无人机正变得越来越普遍。然而&#xff0c;无人机的效能很大程度上受到其通信系统的限制&#xff0c;尤其是在城市这种高楼林立、障碍物众多的环境中。为…

五一旅游必备物品清单 建议把这份清单记在备忘录

五一小长假就要来临&#xff0c;相信很多人已经跃跃欲试&#xff0c;准备带着家人或朋友外出旅游&#xff0c;享受这难得的休闲时光。出游总是让人兴奋不已&#xff0c;但带小孩出游&#xff0c;行李准备可是一项大工程。为了让旅程更加顺利&#xff0c;提前列一份必备物品清单…

Python绘制3D曲面图

&#x1f47d;发现宝藏 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。【点击进入巨牛的人工智能学习网站】。 探索Python中绘制3D曲面图的艺术 在数据可视化的世界中&#xff0c;3D曲面图是一种强大的工…

压铸机PQ控制阀比例放大器

压铸机PQ控制阀比例放大器是确保压铸机正常工作的重要组成部分&#xff0c;它通常由多种液压元件组成&#xff0c;负责提供动力和控制系统中各个部件的运动。液压系统通过液体&#xff08;通常是油&#xff09;传递压力能&#xff0c;以驱动机械装置工作。在压铸机中&#xff0…

ElasticSearch教程入门到精通——第五部分(基于ELK技术栈elasticsearch 7.x+8.x新特性)

ElasticSearch教程入门到精通——第五部分&#xff08;基于ELK技术栈elasticsearch 7.x8.x新特性&#xff09; 1. Elasticsearch集成1.1 框架集成-SpringData-整体介绍1.2 Spring Data Elasticsearch 介绍1.3 框架集成-SpringData-代码功能集成1.3.1 创建Maven项目1.3.2 修改po…

[C++] 类和对象 _ 剖析构造、析构与拷贝

一、构造函数 构造函数是特殊的成员函数&#xff0c;它在创建对象时自动调用。其主要作用是初始化对象的成员变量&#xff08;不是开辟空间&#xff09;。构造函数的名字必须与类名相同&#xff0c;且没有返回类型&#xff08;即使是void也不行&#xff09;。 在C中&#xff0…

Yolov5简单部署(使用自己的数据集)

一.注意事项 1.本文主要是引用大佬的文章&#xff08;侵权请联系&#xff0c;马上删除&#xff09;&#xff0c;做的工作为简单补充 二.正文 1.大体流程按照 准备&#xff1a;【简单易懂&#xff0c;一看就会】yolov5保姆级环境搭建_哔哩哔哩_bilibili 主要过程&#xff1…

Java | Leetcode Java题解之第55题跳跃游戏

题目&#xff1a; 题解&#xff1a; public class Solution {public boolean canJump(int[] nums) {int n nums.length;int rightmost 0;for (int i 0; i < n; i) {if (i < rightmost) {rightmost Math.max(rightmost, i nums[i]);if (rightmost > n - 1) {retu…

VitePress 构建的博客如何部署到 github 平台?

VitePress 构建的博客如何部署到 github 平台&#xff1f; 1. 新建 github 项目 2. 构建 VitePress 项目 2.1. 设置 config 中的 base 由于我们的项目名称为 vite-press-demo&#xff0c;所以我们把 base 设置为 /vite-press-demo/&#xff0c;需注意前后 / export default…

tidb离线本地安装及mysql迁移到tidb

一、背景&#xff08;tidb8.0社区版&#xff09; 信创背景下不多说好吧&#xff0c;从资料上查tidb和OceanBase“兼容”&#xff08;这个词有意思&#xff09;的比较好。 其实对比了很多数据库&#xff0c;有些是提供云服务的&#xff0c;有些“不像”mysql&#xff0c;综合考虑…

uniapp:K线图,支持H5,APP

使用KLineChart完成K线图制作,完成效果: 1、安装KLineChart npm install klinecharts2、页面中使用 <template><view class="index"><!-- 上方选项卡 --><view class="kline-tabs"><view :style="{color: current==ite…

Windows使用bat远程操作Linux并执行命令

背景&#xff1a;让客户可以简单在Windows中能自己执行 Linux中的脚本&#xff0c;傻瓜式操作&#xff01; 方法&#xff1a;做一个简单的bat脚本&#xff01;能远程连接到Linux&#xff0c;并执行Linux命令&#xff01;客户双击就能使用&#xff01; 1、原先上网查询到使用P…

深度学习:基于Keras框架,使用神经网络模型对葡萄酒类型进行预测分析

前言 系列专栏&#xff1a;机器学习&#xff1a;高级应用与实践【项目实战100】【2024】✨︎ 在本专栏中不仅包含一些适合初学者的最新机器学习项目&#xff0c;每个项目都处理一组不同的问题&#xff0c;包括监督和无监督学习、分类、回归和聚类&#xff0c;而且涉及创建深度学…
最新文章