Leetcode 311. Sparse Matrix Multiplication
思路
这道题最暴力的解法其实就是我们解决矩阵相乘的正常解法。我们需要用到三个loop去遍历两个矩阵的每一个元素并且做乘法和加法。但是这道题的特殊条件是sparse matrix,意味着这个矩阵可能包含很多0。而0和其他数相乘还是0,所以在这种情况下我们根本不用进行计算,直接略过就好了。
第一种做法就是我们在遍历A[i][k]的时候,先去check A[i][k]是否为0,如果不是我们才去拿B当前col的每一位去计算。这样的好处是很多情况下节省了我们第三个for loop。Leetcode下这种解法比使用HashMap的解法要快,尽管理论上应该是HashMap更快。原因可能是HashMap操作的overhead。
第二种解法和第三种解法类似,就是使用一个或者两个HashMap,只存储所有不为0的entry信息。而且我们实际上只需要去存储row and col index就好了,因为我们原来的matrix还在,我们取值的时候还是可以去原来的matrix去取。
第四种解法就是使用sort instead of using HashMap。这里对于每个矩阵先根据row排序,再根据col排序。真正走loop的时候逻辑是一样的,只不过这里我们就得手动区分出每一行和每一列呢,不如HashMap写起来直接。
代码
以下代码是使用两个HashMap的解法,理论上是最优的解法,但是Big-O的角度worst case还是O(n^3),因为我们不知道有多sparse。
time O(n^3), space O(n^2)
public class Solution {
public int[][] multiply(int[][] A, int[][] B) {
// check edge case
if (A == null || B == null) {
return new int[0][0];
}
// preprocess
int aRow = A.length;
int aCol = A[0].length;
int bRow = B.length;
int bCol = B[0].length;
int[][] result = new int[aRow][bCol];
HashMap<Integer, HashSet<Integer>> mapB = new HashMap<Integer, HashSet<Integer>>();
for (int k = 0; k < bRow; k++) {
HashSet<Integer> curSet = new HashSet<Integer>();
for (int j = 0; j < bCol; j++) {
if (B[k][j] != 0) {
curSet.add(j);
}
}
mapB.put(k, curSet);
}
HashMap<Integer, HashSet<Integer>> mapA = new HashMap<Integer, HashSet<Integer>>();
for (int i = 0; i < aRow; i++) {
HashSet<Integer> curSet = new HashSet<Integer>();
for (int k = 0; k < aCol; k++) {
if (A[i][k] != 0) {
curSet.add(k);
}
}
mapA.put(i, curSet);
}
// main loop
for (int i = 0; i < aRow; i++) {
for (int k : mapA.get(i)) {
for (int j : mapB.get(k)) {
result[i][j] += A[i][k] * B[k][j];
}
}
}
return result;
}
}
reference: Leetcode discussion
time O(n^3), space O(n^2)
如果题目给的input不是matrix,而是array,那么我们只需要sort array,然后loop就可以了,这样我们就不需要HashMap了
public class Solution {
public int[][] multiply(int[][] A, int[][] B) {
// check edge case
if (A == null || B == null) {
return new int[0][0];
}
// preprocess
int aRow = A.length;
int aCol = A[0].length;
int bRow = B.length;
int bCol = B[0].length;
int[][] result = new int[aRow][bCol];
List<int[]> vectorA = new ArrayList<int[]>();
List<int[]> vectorB = new ArrayList<int[]>();
for (int i = 0; i < aRow; i++) {
for (int k = 0; k < aCol; k++) {
if (A[i][k] != 0) {
vectorA.add(new int[]{i, k});
}
}
}
for (int k = 0; k < bRow; k++) {
for (int j = 0; j < bCol; j++) {
if (B[k][j] != 0) {
vectorB.add(new int[]{k, j});
}
}
}
// main loop
int i = 0;
int vectorAIndex = 0;
int vectorBIndex = 0;
while (i < aRow && vectorAIndex < vectorA.size()) {
vectorBIndex = 0;
for (; vectorAIndex < vectorA.size() && vectorA.get(vectorAIndex)[0] == i; vectorAIndex++) {
int k = vectorA.get(vectorAIndex)[1];
while (vectorBIndex < vectorB.size() && vectorB.get(vectorBIndex)[0] < k) {
vectorBIndex++;
}
for (; vectorBIndex < vectorB.size() && vectorB.get(vectorBIndex)[0] == k; vectorBIndex++) {
int j = vectorB.get(vectorBIndex)[1];
result[i][j] += A[i][k] * B[k][j];
}
}
i++;
}
return result;
}
}
reference: 一亩三分地