这是表结构部分的最后一篇文章,我们来聊一聊稀疏矩阵的压缩方法以及快速转置。
这部分需要一点线性代数的知识,如果没学过线性代数的话可以先跳过这部分,等学习完线性代数再回来阅读~
矩阵
在编程语言中,矩阵其实可以看作一个二维数组,比如python的numpy中的矩阵,可以和二维数组自由地进行转换。
在C++中,我们虽然没有集成好的numpy库,但我们也可以自己实现一个矩阵。
1 2 3 4 5 6
| vector<vector<int>> matrix({ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 8, 5, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {3, 0, 9, 0, 2, 0, 0, 0, 0, 0}, {0, 7, 0, 0, 0, 0, 0, 0, 0, 0} });
|
这样,我们的矩阵就完成了。
稀疏矩阵
我们会发现,构造一个n*m大小的矩阵,就需要开辟n*m的空间,但如果矩阵中存在大量的相同元素(如以上例子),那就会造成大量的空间浪费。这时我们想,有没有一种算法可以对这种矩阵实现压缩,从而减少空间的浪费?
我们可以从一种特殊的矩阵来入手研究——稀疏矩阵。
矩阵中非零元素的个数远远小于矩阵元素的总数,并且非零元素的分布没有规律,通常认为矩阵中非零元素的总数比上矩阵所有元素总数的值小于等于0.05时,则称该矩阵为稀疏矩阵(sparse matrix) ,该比值称为这个矩阵的稠密度;
稀疏矩阵几乎产生于所有的大型科学工程计算领域,包括计算流体力学、统计物理、电路模拟、图像处理、纳米材料计算等。
由于严格的稀疏矩阵中非零数据太少,不利于我们学习,我就用上方例子中给出的那个矩阵来介绍。
构造稀疏矩阵
我们发现,图中只有六个非零元素(在实际的稀疏矩阵中占比会更少),那我们是不是只需要记录这六个元素的信息(没记录的部分全都为零嘛),就可以实现记录整个矩阵了呢?
对于矩阵中的元素,我们需要记录的信息只有:
1 2 3 4 5 6
| typedef struct { int row; int col; int val; } smData;
|
这样就可以很容易地构造出一个数组来存放这些信息,顺便再记录一下矩阵的行列数(不然输出的时候怎么知道有几行几列):
1 2 3 4 5 6
| vector<smData> sparseMatrix;
int row_size = 0; int col_size = 0;
|
这个数组的大小为6*3=18,远远小于用二维数组记录的5*10=50。
于是,这个稀疏矩阵类的构造函数部分就很容易写出来了:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| SparseMatrix::SparseMatrix(vector<vector<int>> matrix) { row_size = matrix.size(); for (int i = 0; i < row_size; i++) { col_size = matrix[i].size(); for (int j = 0; j < col_size; j++) { if (matrix[i][j]) { smData temp; temp.row = i; temp.col = j; temp.val = matrix[i][j]; sparseMatrix.push_back(temp); } } } }
|
这段代码应该很容易理解,无非就是遍历一遍矩阵,然后将非零元素转化为三元组,并且插入数据数组罢了。
此时的三元组数组:
很显然,这种构造方法只适用于类稀疏矩阵,如果是极其稠密的矩阵,反而会浪费更多的存储空间。
输出稀疏矩阵
回忆一下刚才构造稀疏矩阵的过程,我们是按照逐行的顺序压缩的,所以输出的时候也需要按行优先的顺序进行扫描,依次将三元组数组中的元素按坐标输出:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| ostream &operator<<(ostream &os, const SparseMatrix &sm) { int k = 0; for (int i = 0; i < sm.row_size; i++) { for (int j = 0; j < sm.col_size; j++) { if (i == sm.sparseMatrix[k].row && j == sm.sparseMatrix[k].col) { os << sm.sparseMatrix[k].val << ' '; k++; } else os << 0 << ' '; } os << endl; } return os; }
|
这边我们保存了一个k指针,用来在三元组数组中移动,每找到当前元素的对应坐标,就将其输出,并将光标后移,依次输出。
输出结果:
1 2 3 4 5
| 0 0 0 0 0 0 0 0 0 0 0 0 8 5 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 0 9 0 2 0 0 0 0 0 0 7 0 0 0 0 0 0 0 0
|
转置
转置是一个数学名词。直观来看,将A的所有元素绕着一条从第1行第1列元素出发的右下方45度的射线作镜面反转,即得到A的转置。一个矩阵M, 把它的第一行变成第一列,第二行变成第二列,…,最末一行变为最末一列, 从而得到一个新的矩阵N。 这一过程称为矩阵的转置。即矩阵A的行和列对应互换。
说简单点,就是对于每个元素,将其行坐标与列坐标互换。
那对于三元组来说,不是更加容易?只要将每个三元组的行列信息互换不就好了吗?
然而互换完之后……
乍一看好像没什么问题,但是输出一下……?完了,行信息不是有序的了,没法正常输出了。我们得想一个办法让其变成有序的。
其实这里可以直接用c++的sort()方法排序的,但是我要介绍另外一种更加快速的方法。
快速转置
其实,对于某个元素,我们只要先对其列坐标进行分析,判断其在转置后的三元组数组中应该出现在什么位置,就可以实现顺序转置了。
而这个位置要如何判断呢?
显然,我们只需要知道每一列中第一个元素在新数组中的出现位置(在转置后即变成每一行的第一个元素),然后在转置的过程中,每当该行添加一个元素,就将这个位置加一。
我们举个例子:在一开始的时候,转置前第2列中第1个元素应该出现在第2位(关于这个第2位是怎么来的,我们一会再聊),毕竟第0位和第1位要留给第0列和第1列的那两个元素嘛。
好,那这时我们开始处理原三元组中的第一个元素(1,2,8),发现它的列坐标为2,对应应该出现在新数组中的第2位,将其互换行列后放入。然后第2列的首位位置加一,变为第3位,此时如果还有第2列的元素出现,就将其放在新数组中的第3位。依此类推,能够实现所有元素的正确入组。
那么,这个每列对应的首位位置要怎么来呢……?其实很容易想到,只要用前一列的首位位置加上前一列的元素个数就可以得到。
记nz_num[]为每列非零元素的个数,nz_index[]为每列第一个非零元素在新数组中出现的位置,很容易得出公式:
1
| nz_index[i] = nz_index[i - 1] + nz_num[i - 1];
|
这样一来,我们的转置代码就呼之欲出了:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
| void SparseMatrix::transposition() { nz_num.resize(col_size, 0); nz_index.resize(col_size, 0);
vector<smData> temp((int)sparseMatrix.size());
for (size_t i = 0; i < sparseMatrix.size(); i++) { nz_num[sparseMatrix[i].col]++; }
for (int i = 1; i < col_size; i++) { nz_index[i] = nz_index[i - 1] + nz_num[i - 1]; }
for (size_t i = 0; i < sparseMatrix.size(); i++) { int index = nz_index[sparseMatrix[i].col]; temp[index].col = sparseMatrix[i].row; temp[index].row = sparseMatrix[i].col; temp[index].val = sparseMatrix[i].val; nz_index[sparseMatrix[i].col]++; }
sparseMatrix = temp;
int tempCol = col_size; col_size = row_size; row_size = tempCol; }
|
可以看出,这个转置算法只需遍历两次三元组,再遍历一次储存非零元素首次出现位置的组,所需的时间很短。
最后放上完整代码:
完整代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
|
#include <bits/stdc++.h> using namespace std;
class SparseMatrix { typedef struct { int row; int col; int val; } smData;
vector<smData> sparseMatrix;
int row_size = 0; int col_size = 0;
vector<int> nz_num; vector<int> nz_index;
public: SparseMatrix(vector<vector<int>> matrix); void transposition(); friend ostream &operator<<(ostream &os, const SparseMatrix &sm); };
int main() { vector<vector<int>> matrix({{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 8, 5, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {3, 0, 9, 0, 2, 0, 0, 0, 0, 0}, {0, 7, 0, 0, 0, 0, 0, 0, 0, 0}}); SparseMatrix sm(matrix); cout << sm; sm.transposition(); cout << sm; }
SparseMatrix::SparseMatrix(vector<vector<int>> matrix) { row_size = matrix.size(); for (int i = 0; i < row_size; i++) { col_size = matrix[i].size(); for (int j = 0; j < col_size; j++) { if (matrix[i][j]) { smData temp; temp.row = i; temp.col = j; temp.val = matrix[i][j]; sparseMatrix.push_back(temp); } } } }
void SparseMatrix::transposition() { nz_num.resize(col_size, 0); nz_index.resize(col_size, 0);
vector<smData> temp((int)sparseMatrix.size());
for (size_t i = 0; i < sparseMatrix.size(); i++) { nz_num[sparseMatrix[i].col]++; }
for (int i = 1; i < col_size; i++) { nz_index[i] = nz_index[i - 1] + nz_num[i - 1]; }
for (size_t i = 0; i < sparseMatrix.size(); i++) { int index = nz_index[sparseMatrix[i].col]; temp[index].col = sparseMatrix[i].row; temp[index].row = sparseMatrix[i].col; temp[index].val = sparseMatrix[i].val; nz_index[sparseMatrix[i].col]++; }
sparseMatrix = temp;
int tempCol = col_size; col_size = row_size; row_size = tempCol; }
ostream &operator<<(ostream &os, const SparseMatrix &sm) { int k = 0; for (int i = 0; i < sm.row_size; i++) { for (int j = 0; j < sm.col_size; j++) { if (i == sm.sparseMatrix[k].row && j == sm.sparseMatrix[k].col) { os << sm.sparseMatrix[k].val << ' '; k++; } else os << 0 << ' '; } os << endl; } return os; }
|