Leetcode 311. Sparse Matrix Multiplication

思路

这道题最暴力的解法其实就是我们解决矩阵相乘的正常解法。我们需要用到三个loop去遍历两个矩阵的每一个元素并且做乘法和加法。但是这道题的特殊条件是sparse matrix,意味着这个矩阵可能包含很多0。而0和其他数相乘还是0,所以在这种情况下我们根本不用进行计算,直接略过就好了。

第一种做法就是我们在遍历A[i][k]的时候,先去check A[i][k]是否为0,如果不是我们才去拿B当前col的每一位去计算。这样的好处是很多情况下节省了我们第三个for loop。Leetcode下这种解法比使用HashMap的解法要快,尽管理论上应该是HashMap更快。原因可能是HashMap操作的overhead。

第二种解法和第三种解法类似,就是使用一个或者两个HashMap,只存储所有不为0的entry信息。而且我们实际上只需要去存储row and col index就好了,因为我们原来的matrix还在,我们取值的时候还是可以去原来的matrix去取。

第四种解法就是使用sort instead of using HashMap。这里对于每个矩阵先根据row排序,再根据col排序。真正走loop的时候逻辑是一样的,只不过这里我们就得手动区分出每一行和每一列呢,不如HashMap写起来直接。

代码

以下代码是使用两个HashMap的解法,理论上是最优的解法,但是Big-O的角度worst case还是O(n^3),因为我们不知道有多sparse。

time O(n^3), space O(n^2)

  public class Solution {
    public int[][] multiply(int[][] A, int[][] B) {
        // check edge case
        if (A == null || B == null) {
            return new int[0][0];
        }
        // preprocess
        int aRow = A.length;
        int aCol = A[0].length;
        int bRow = B.length;
        int bCol = B[0].length;
        int[][] result = new int[aRow][bCol];

        HashMap<Integer, HashSet<Integer>> mapB = new HashMap<Integer, HashSet<Integer>>();
        for (int k = 0; k < bRow; k++) {
            HashSet<Integer> curSet = new HashSet<Integer>();
            for (int j = 0; j < bCol; j++) {
                if (B[k][j] != 0) {
                    curSet.add(j);
                }
            }
            mapB.put(k, curSet);
        }

        HashMap<Integer, HashSet<Integer>> mapA = new HashMap<Integer, HashSet<Integer>>();
        for (int i = 0; i < aRow; i++) {
            HashSet<Integer> curSet = new HashSet<Integer>();
            for (int k = 0; k < aCol; k++) {
                if (A[i][k] != 0) {
                    curSet.add(k);
                }
            }
            mapA.put(i, curSet);
        }
        // main loop
        for (int i = 0; i < aRow; i++) {
            for (int k : mapA.get(i)) {
                for (int j : mapB.get(k)) {
                    result[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        return result;
    }
}

reference: Leetcode discussion

time O(n^3), space O(n^2)

如果题目给的input不是matrix,而是array,那么我们只需要sort array,然后loop就可以了,这样我们就不需要HashMap了

  public class Solution {
    public int[][] multiply(int[][] A, int[][] B) {
        // check edge case
        if (A == null || B == null) {
            return new int[0][0];
        }
        // preprocess
        int aRow = A.length;
        int aCol = A[0].length;
        int bRow = B.length;
        int bCol = B[0].length;
        int[][] result = new int[aRow][bCol];

        List<int[]> vectorA = new ArrayList<int[]>();
        List<int[]> vectorB = new ArrayList<int[]>();

        for (int i = 0; i < aRow; i++) {
            for (int k = 0; k < aCol; k++) {
                if (A[i][k] != 0) {
                    vectorA.add(new int[]{i, k});
                }
            }
        }

        for (int k = 0; k < bRow; k++) {
            for (int j = 0; j < bCol; j++) {
                if (B[k][j] != 0) {
                    vectorB.add(new int[]{k, j});
                }
            }
        }
        // main loop
        int i = 0;
        int vectorAIndex = 0;
        int vectorBIndex = 0;

        while (i < aRow && vectorAIndex < vectorA.size()) {
            vectorBIndex = 0;
            for (; vectorAIndex < vectorA.size() && vectorA.get(vectorAIndex)[0] == i; vectorAIndex++) {
                int k = vectorA.get(vectorAIndex)[1];
                while (vectorBIndex < vectorB.size() && vectorB.get(vectorBIndex)[0] < k) {
                    vectorBIndex++;
                }
                for (; vectorBIndex < vectorB.size() && vectorB.get(vectorBIndex)[0] == k; vectorBIndex++) {
                    int j = vectorB.get(vectorBIndex)[1];
                    result[i][j] += A[i][k] * B[k][j];
                }
            }
            i++;
        }
        return result;
    }
}

reference: 一亩三分地

results matching ""

    No results matching ""