/**
 * Micro-benchmark comparing dense square matrix multiplication over two storage layouts:
 * an array of row arrays ({@code double[][]}) and a single row-major array ({@code double[]}).
 * Each layout is timed with a plain inner-product kernel and with a "T" kernel that transposes
 * the right-hand operand in place first, so the inner loop reads both operands along rows.
 *
 * Created by Xan Gregg. Date: Jun 19, 2007
 */
public class MultiplyTimer {

    /** Sanity-check tolerance, scaled by the magnitude of the expected value (at least 1). */
    private static final double RELATIVE_TOLERANCE = 1e-9;

    /**
     * Times each variant on 200x200 matrices over 100 runs and prints the mean milliseconds per run.
     * Each variant is timed twice in a row; presumably the first call absorbs JIT warm-up.
     */
    public static void main(String[] args) {
        int n = 200;
        int runs = 100;
        System.out.printf("ArrayOfArrays : %5.1f ms\n", timeMatrixArrayOfArrays(n, runs));
        System.out.printf("ArrayOfArrays : %5.1f ms\n", timeMatrixArrayOfArrays(n, runs));
        System.out.printf("ArrayOfArraysT: %5.1f ms\n", timeMatrixArrayOfArraysTranspose(n, runs));
        System.out.printf("ArrayOfArraysT: %5.1f ms\n", timeMatrixArrayOfArraysTranspose(n, runs));
        System.out.printf("Array : %5.1f ms\n", timeMatrixArray(n, runs));
        System.out.printf("Array : %5.1f ms\n", timeMatrixArray(n, runs));
        System.out.printf("ArrayT : %5.1f ms\n", timeMatrixArrayTranspose(n, runs));
        System.out.printf("ArrayT : %5.1f ms\n", timeMatrixArrayTranspose(n, runs));
    }

    /**
     * Mean milliseconds per run of {@code c = a*b; c += a; c += b} using array-of-arrays storage
     * and the plain {@link #multiplySquare(int, double[][], double[][], double[][])} kernel.
     *
     * @param n    matrix dimension, must be positive
     * @param runs number of timed repetitions, must be positive
     * @throws IllegalArgumentException if {@code n} or {@code runs} is not positive
     * @throws RuntimeException         ("wrong answer") if the result fails the sanity check
     */
    private static double timeMatrixArrayOfArrays(int n, int runs) {
        checkArguments(n, runs);
        double[][] a = new double[n][];
        double[][] b = new double[n][];
        double[][] c = new double[n][];
        initMatrix(a, 3, 1, 300);
        initMatrix(b, 2, 7, 300);
        initMatrix(c, 0, 0, 100);
        long start = System.nanoTime();
        for (int i = 0; i < runs; i++) {
            multiplySquare(n, c, a, b);
            addSquare(n, c, a);
            addSquare(n, c, b);
        }
        long end = System.nanoTime();
        verify(c[n - 1][n - 1], expectedCorner(a, b));
        return millisPerRun(start, end, runs);
    }

    /**
     * Same as {@link #timeMatrixArrayOfArrays(int, int)} but using the transposing kernel
     * {@link #multiplySquareT(int, double[][], double[][], double[][])}.
     *
     * @param n    matrix dimension, must be positive
     * @param runs number of timed repetitions, must be positive
     * @throws IllegalArgumentException if {@code n} or {@code runs} is not positive
     * @throws RuntimeException         ("wrong answer") if the result fails the sanity check
     */
    private static double timeMatrixArrayOfArraysTranspose(int n, int runs) {
        checkArguments(n, runs);
        double[][] a = new double[n][];
        double[][] b = new double[n][];
        double[][] c = new double[n][];
        initMatrix(a, 3, 1, 300);
        initMatrix(b, 2, 7, 300);
        initMatrix(c, 0, 0, 100);
        long start = System.nanoTime();
        for (int i = 0; i < runs; i++) {
            multiplySquareT(n, c, a, b);
            addSquare(n, c, a);
            addSquare(n, c, b);
        }
        long end = System.nanoTime();
        verify(c[n - 1][n - 1], expectedCorner(a, b));
        return millisPerRun(start, end, runs);
    }

    /** Element-wise {@code c += a} over the leading {@code n x n} entries. */
    private static void addSquare(int n, double[][] c, double[][] a) {
        for (int i = 0; i < n; i++) {
            double[] ai = a[i];
            double[] ci = c[i];
            for (int j = 0; j < n; j++) {
                ci[j] += ai[j];
            }
        }
    }

    /**
     * Allocates every row of {@code m} as a fresh array of length {@code m.length} and fills
     * {@code m[i][j] = (a*i + b*j) / scale}.
     */
    private static void initMatrix(double[][] m, int a, int b, double scale) {
        for (int i = 0; i < m.length; i++) {
            m[i] = new double[m.length];
            for (int j = 0; j < m[i].length; j++) {
                m[i][j] = (a * i + b * j) / scale;
            }
        }
    }

    /**
     * Computes {@code m3 = m1 * m2} for the leading {@code n x n} block with the textbook
     * i-j-k loop; the inner loop reads {@code m2} down a column.
     * {@code m3} must not be the same array as {@code m1} or {@code m2}.
     */
    public static void multiplySquare(int n, double[][] m3, double[][] m1, double[][] m2) {
        for (int i = 0; i < n; i++) {
            double[] m1i = m1[i];
            double[] m3i = m3[i];
            for (int j = 0; j < n; j++) {
                double val = 0;
                for (int k = 0; k < n; k++) {
                    val += m1i[k] * m2[k][j];
                }
                m3i[j] = val;
            }
        }
    }

    /**
     * Computes {@code m3 = m1 * m2} by transposing {@code m2} in place, taking row-by-row dot
     * products, then transposing {@code m2} back. {@code m2} is restored even if the loop throws,
     * but it is temporarily modified, so it must not be shared with concurrent readers.
     * {@code m3} must not be the same array as {@code m1} or {@code m2}.
     */
    public static void multiplySquareT(int n, double[][] m3, double[][] m1, double[][] m2) {
        if (m1 == m2) {
            // Transposing m2 in place would transpose m1 as well and yield m1^T * m1.
            multiplySquare(n, m3, m1, m2);
            return;
        }
        transposeSquare(m2);
        try {
            for (int i = 0; i < n; i++) {
                double[] m1i = m1[i];
                double[] m3i = m3[i];
                for (int j = 0; j < n; j++) {
                    double[] m2j = m2[j]; // row j of the transpose = column j of the original
                    double val = 0;
                    for (int k = 0; k < n; k++) {
                        val += m1i[k] * m2j[k];
                    }
                    m3i[j] = val;
                }
            }
        } finally {
            transposeSquare(m2); // restore the caller's matrix
        }
    }

    /** Transposes the square matrix {@code m} in place by swapping entries across the diagonal. */
    private static void transposeSquare(double[][] m) {
        for (int i = 0; i < m.length; i++) {
            for (int j = i + 1; j < m[i].length; j++) {
                double t = m[i][j];
                m[i][j] = m[j][i];
                m[j][i] = t;
            }
        }
    }

    /**
     * Mean milliseconds per run of {@code c = a*b; c += a; c += b} using row-major
     * {@code double[n*n]} storage and the plain
     * {@link #multiplySquare(int, double[], double[], double[])} kernel.
     *
     * @param n    matrix dimension, must be positive
     * @param runs number of timed repetitions, must be positive
     * @throws IllegalArgumentException if {@code n} or {@code runs} is not positive
     * @throws RuntimeException         ("wrong answer") if the result fails the sanity check
     */
    private static double timeMatrixArray(int n, int runs) {
        checkArguments(n, runs);
        double[] a = new double[n * n];
        double[] b = new double[n * n];
        double[] c = new double[n * n];
        initMatrix(n, a, 3, 1, 300);
        initMatrix(n, b, 2, 7, 300);
        initMatrix(n, c, 0, 0, 100);
        long start = System.nanoTime();
        for (int i = 0; i < runs; i++) {
            multiplySquare(n, c, a, b);
            addSquare(n, c, a);
            addSquare(n, c, b);
        }
        long end = System.nanoTime();
        verify(c[n * n - 1], expectedCorner(n, a, b));
        return millisPerRun(start, end, runs);
    }

    /**
     * Same as {@link #timeMatrixArray(int, int)} but using the transposing kernel
     * {@link #multiplySquareT(int, double[], double[], double[])}.
     *
     * @param n    matrix dimension, must be positive
     * @param runs number of timed repetitions, must be positive
     * @throws IllegalArgumentException if {@code n} or {@code runs} is not positive
     * @throws RuntimeException         ("wrong answer") if the result fails the sanity check
     */
    private static double timeMatrixArrayTranspose(int n, int runs) {
        checkArguments(n, runs);
        double[] a = new double[n * n];
        double[] b = new double[n * n];
        double[] c = new double[n * n];
        initMatrix(n, a, 3, 1, 300);
        initMatrix(n, b, 2, 7, 300);
        initMatrix(n, c, 0, 0, 100);
        long start = System.nanoTime();
        for (int i = 0; i < runs; i++) {
            multiplySquareT(n, c, a, b);
            addSquare(n, c, a);
            addSquare(n, c, b);
        }
        long end = System.nanoTime();
        verify(c[n * n - 1], expectedCorner(n, a, b));
        return millisPerRun(start, end, runs);
    }

    /** Fills the row-major {@code n x n} matrix {@code m} with {@code m[i][j] = (a*i + b*j) / scale}. */
    private static void initMatrix(int n, double[] m, int a, int b, double scale) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                m[i * n + j] = (a * i + b * j) / scale;
            }
        }
    }

    /** Element-wise {@code c += a} over the first {@code n*n} entries. */
    private static void addSquare(int n, double[] c, double[] a) {
        int nn = n * n;
        for (int i = 0; i < nn; i++) {
            c[i] += a[i];
        }
    }

    /**
     * Computes {@code m3 = m1 * m2} for row-major {@code n x n} matrices with the textbook
     * i-j-k loop; the inner loop reads {@code m2} with stride {@code n}.
     * {@code m3} must not be the same array as {@code m1} or {@code m2}.
     */
    public static void multiplySquare(int n, double[] m3, double[] m1, double[] m2) {
        for (int i = 0, in = 0; i < n; i++, in += n) {
            for (int j = 0; j < n; j++) {
                double val = 0;
                for (int k = 0, kn = 0; k < n; k++, kn += n) {
                    val += m1[in + k] * m2[kn + j];
                }
                m3[in + j] = val;
            }
        }
    }

    /**
     * Computes {@code m3 = m1 * m2} for row-major {@code n x n} matrices by transposing {@code m2}
     * in place, taking contiguous dot products, then transposing {@code m2} back. {@code m2} is
     * restored even if the loop throws, but it is temporarily modified.
     * {@code m3} must not be the same array as {@code m1} or {@code m2}.
     */
    public static void multiplySquareT(int n, double[] m3, double[] m1, double[] m2) {
        if (m1 == m2) {
            // Transposing m2 in place would transpose m1 as well and yield m1^T * m1.
            multiplySquare(n, m3, m1, m2);
            return;
        }
        transposeSquare(n, m2);
        try {
            for (int i = 0, in = 0; i < n; i++, in += n) {
                for (int j = 0, jn = 0; j < n; j++, jn += n) {
                    double val = 0;
                    for (int k = 0; k < n; k++) {
                        val += m1[in + k] * m2[jn + k];
                    }
                    m3[in + j] = val;
                }
            }
        } finally {
            transposeSquare(n, m2); // restore the caller's matrix
        }
    }

    /** Transposes the row-major {@code n x n} matrix {@code m} in place. */
    private static void transposeSquare(int n, double[] m) {
        for (int i = 0, in = 0; i < n; i++, in += n) {
            // jn must track j * n; starting it at 0 (as before) swapped the wrong entries.
            for (int j = i + 1, jn = (i + 1) * n; j < n; j++, jn += n) {
                double t = m[in + j];
                m[in + j] = m[jn + i];
                m[jn + i] = t;
            }
        }
    }

    /** Rejects dimensions and run counts that would index out of bounds or divide by zero. */
    private static void checkArguments(int n, int runs) {
        if (n <= 0) {
            throw new IllegalArgumentException("n must be positive: " + n);
        }
        if (runs <= 0) {
            throw new IllegalArgumentException("runs must be positive: " + runs);
        }
    }

    /**
     * Reference value of {@code (a*b)[n-1][n-1] + a[n-1][n-1] + b[n-1][n-1]} for square
     * array-of-arrays matrices, summed in the same order the kernels use.
     */
    private static double expectedCorner(double[][] a, double[][] b) {
        int last = a.length - 1;
        double sum = 0;
        for (int k = 0; k < a.length; k++) {
            sum += a[last][k] * b[k][last];
        }
        return sum + a[last][last] + b[last][last];
    }

    /**
     * Reference value of {@code (a*b)[n-1][n-1] + a[n-1][n-1] + b[n-1][n-1]} for row-major
     * {@code n x n} matrices, summed in the same order the kernels use.
     */
    private static double expectedCorner(int n, double[] a, double[] b) {
        int lastRow = (n - 1) * n;
        double sum = 0;
        for (int k = 0, kn = 0; k < n; k++, kn += n) {
            sum += a[lastRow + k] * b[kn + n - 1];
        }
        return sum + a[lastRow + n - 1] + b[lastRow + n - 1];
    }

    /** Throws "wrong answer" unless {@code actual} is within tolerance of {@code expected}. */
    private static void verify(double actual, double expected) {
        double tolerance = RELATIVE_TOLERANCE * Math.max(1.0, Math.abs(expected));
        // Written as a negated <= so that a NaN result also counts as wrong.
        if (!(Math.abs(actual - expected) <= tolerance)) {
            throw new RuntimeException("wrong answer");
        }
    }

    /** Converts a {@link System#nanoTime()} interval into mean milliseconds per run. */
    private static double millisPerRun(long startNanos, long endNanos, int runs) {
        return (endNanos - startNanos) / 1e6 / runs;
    }
}