xref: /aosp_15_r20/external/apache-commons-math/src/main/java/org/apache/commons/math3/linear/MatrixUtils.java (revision d3fac44428dd0296a04a50c6827e3205b8dbea8a)
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 package org.apache.commons.math3.linear;
19 
20 import org.apache.commons.math3.Field;
21 import org.apache.commons.math3.FieldElement;
22 import org.apache.commons.math3.exception.DimensionMismatchException;
23 import org.apache.commons.math3.exception.MathArithmeticException;
24 import org.apache.commons.math3.exception.NoDataException;
25 import org.apache.commons.math3.exception.NullArgumentException;
26 import org.apache.commons.math3.exception.NumberIsTooSmallException;
27 import org.apache.commons.math3.exception.OutOfRangeException;
28 import org.apache.commons.math3.exception.ZeroException;
29 import org.apache.commons.math3.exception.util.LocalizedFormats;
30 import org.apache.commons.math3.fraction.BigFraction;
31 import org.apache.commons.math3.fraction.Fraction;
32 import org.apache.commons.math3.util.FastMath;
33 import org.apache.commons.math3.util.MathArrays;
34 import org.apache.commons.math3.util.MathUtils;
35 import org.apache.commons.math3.util.Precision;
36 
37 import java.io.IOException;
38 import java.io.ObjectInputStream;
39 import java.io.ObjectOutputStream;
40 import java.util.Arrays;
41 
42 /** A collection of static methods that operate on or return matrices. */
43 public class MatrixUtils {
44 
45     /**
46      * The default format for {@link RealMatrix} objects.
47      *
48      * @since 3.1
49      */
50     public static final RealMatrixFormat DEFAULT_FORMAT = RealMatrixFormat.getInstance();
51 
52     /**
53      * A format for {@link RealMatrix} objects compatible with octave.
54      *
55      * @since 3.1
56      */
57     public static final RealMatrixFormat OCTAVE_FORMAT =
58             new RealMatrixFormat("[", "]", "", "", "; ", ", ");
59 
60     /** Private constructor. */
MatrixUtils()61     private MatrixUtils() {
62         super();
63     }
64 
65     /**
66      * Returns a {@link RealMatrix} with specified dimensions.
67      *
68      * <p>The type of matrix returned depends on the dimension. Below 2<sup>12</sup> elements (i.e.
69      * 4096 elements or 64&times;64 for a square matrix) which can be stored in a 32kB array, a
70      * {@link Array2DRowRealMatrix} instance is built. Above this threshold a {@link
71      * BlockRealMatrix} instance is built.
72      *
73      * <p>The matrix elements are all set to 0.0.
74      *
75      * @param rows number of rows of the matrix
76      * @param columns number of columns of the matrix
77      * @return RealMatrix with specified dimensions
78      * @see #createRealMatrix(double[][])
79      */
createRealMatrix(final int rows, final int columns)80     public static RealMatrix createRealMatrix(final int rows, final int columns) {
81         return (rows * columns <= 4096)
82                 ? new Array2DRowRealMatrix(rows, columns)
83                 : new BlockRealMatrix(rows, columns);
84     }
85 
86     /**
87      * Returns a {@link FieldMatrix} with specified dimensions.
88      *
89      * <p>The type of matrix returned depends on the dimension. Below 2<sup>12</sup> elements (i.e.
90      * 4096 elements or 64&times;64 for a square matrix), a {@link FieldMatrix} instance is built.
91      * Above this threshold a {@link BlockFieldMatrix} instance is built.
92      *
93      * <p>The matrix elements are all set to field.getZero().
94      *
95      * @param <T> the type of the field elements
96      * @param field field to which the matrix elements belong
97      * @param rows number of rows of the matrix
98      * @param columns number of columns of the matrix
99      * @return FieldMatrix with specified dimensions
100      * @see #createFieldMatrix(FieldElement[][])
101      * @since 2.0
102      */
createFieldMatrix( final Field<T> field, final int rows, final int columns)103     public static <T extends FieldElement<T>> FieldMatrix<T> createFieldMatrix(
104             final Field<T> field, final int rows, final int columns) {
105         return (rows * columns <= 4096)
106                 ? new Array2DRowFieldMatrix<T>(field, rows, columns)
107                 : new BlockFieldMatrix<T>(field, rows, columns);
108     }
109 
110     /**
111      * Returns a {@link RealMatrix} whose entries are the the values in the the input array.
112      *
113      * <p>The type of matrix returned depends on the dimension. Below 2<sup>12</sup> elements (i.e.
114      * 4096 elements or 64&times;64 for a square matrix) which can be stored in a 32kB array, a
115      * {@link Array2DRowRealMatrix} instance is built. Above this threshold a {@link
116      * BlockRealMatrix} instance is built.
117      *
118      * <p>The input array is copied, not referenced.
119      *
120      * @param data input array
121      * @return RealMatrix containing the values of the array
122      * @throws org.apache.commons.math3.exception.DimensionMismatchException if {@code data} is not
123      *     rectangular (not all rows have the same length).
124      * @throws NoDataException if a row or column is empty.
125      * @throws NullArgumentException if either {@code data} or {@code data[0]} is {@code null}.
126      * @throws DimensionMismatchException if {@code data} is not rectangular.
127      * @see #createRealMatrix(int, int)
128      */
createRealMatrix(double[][] data)129     public static RealMatrix createRealMatrix(double[][] data)
130             throws NullArgumentException, DimensionMismatchException, NoDataException {
131         if (data == null || data[0] == null) {
132             throw new NullArgumentException();
133         }
134         return (data.length * data[0].length <= 4096)
135                 ? new Array2DRowRealMatrix(data)
136                 : new BlockRealMatrix(data);
137     }
138 
139     /**
140      * Returns a {@link FieldMatrix} whose entries are the the values in the the input array.
141      *
142      * <p>The type of matrix returned depends on the dimension. Below 2<sup>12</sup> elements (i.e.
143      * 4096 elements or 64&times;64 for a square matrix), a {@link FieldMatrix} instance is built.
144      * Above this threshold a {@link BlockFieldMatrix} instance is built.
145      *
146      * <p>The input array is copied, not referenced.
147      *
148      * @param <T> the type of the field elements
149      * @param data input array
150      * @return a matrix containing the values of the array.
151      * @throws org.apache.commons.math3.exception.DimensionMismatchException if {@code data} is not
152      *     rectangular (not all rows have the same length).
153      * @throws NoDataException if a row or column is empty.
154      * @throws NullArgumentException if either {@code data} or {@code data[0]} is {@code null}.
155      * @see #createFieldMatrix(Field, int, int)
156      * @since 2.0
157      */
createFieldMatrix(T[][] data)158     public static <T extends FieldElement<T>> FieldMatrix<T> createFieldMatrix(T[][] data)
159             throws DimensionMismatchException, NoDataException, NullArgumentException {
160         if (data == null || data[0] == null) {
161             throw new NullArgumentException();
162         }
163         return (data.length * data[0].length <= 4096)
164                 ? new Array2DRowFieldMatrix<T>(data)
165                 : new BlockFieldMatrix<T>(data);
166     }
167 
168     /**
169      * Returns <code>dimension x dimension</code> identity matrix.
170      *
171      * @param dimension dimension of identity matrix to generate
172      * @return identity matrix
173      * @throws IllegalArgumentException if dimension is not positive
174      * @since 1.1
175      */
createRealIdentityMatrix(int dimension)176     public static RealMatrix createRealIdentityMatrix(int dimension) {
177         final RealMatrix m = createRealMatrix(dimension, dimension);
178         for (int i = 0; i < dimension; ++i) {
179             m.setEntry(i, i, 1.0);
180         }
181         return m;
182     }
183 
184     /**
185      * Returns <code>dimension x dimension</code> identity matrix.
186      *
187      * @param <T> the type of the field elements
188      * @param field field to which the elements belong
189      * @param dimension dimension of identity matrix to generate
190      * @return identity matrix
191      * @throws IllegalArgumentException if dimension is not positive
192      * @since 2.0
193      */
createFieldIdentityMatrix( final Field<T> field, final int dimension)194     public static <T extends FieldElement<T>> FieldMatrix<T> createFieldIdentityMatrix(
195             final Field<T> field, final int dimension) {
196         final T zero = field.getZero();
197         final T one = field.getOne();
198         final T[][] d = MathArrays.buildArray(field, dimension, dimension);
199         for (int row = 0; row < dimension; row++) {
200             final T[] dRow = d[row];
201             Arrays.fill(dRow, zero);
202             dRow[row] = one;
203         }
204         return new Array2DRowFieldMatrix<T>(field, d, false);
205     }
206 
207     /**
208      * Returns a diagonal matrix with specified elements.
209      *
210      * @param diagonal diagonal elements of the matrix (the array elements will be copied)
211      * @return diagonal matrix
212      * @since 2.0
213      */
createRealDiagonalMatrix(final double[] diagonal)214     public static RealMatrix createRealDiagonalMatrix(final double[] diagonal) {
215         final RealMatrix m = createRealMatrix(diagonal.length, diagonal.length);
216         for (int i = 0; i < diagonal.length; ++i) {
217             m.setEntry(i, i, diagonal[i]);
218         }
219         return m;
220     }
221 
222     /**
223      * Returns a diagonal matrix with specified elements.
224      *
225      * @param <T> the type of the field elements
226      * @param diagonal diagonal elements of the matrix (the array elements will be copied)
227      * @return diagonal matrix
228      * @since 2.0
229      */
createFieldDiagonalMatrix( final T[] diagonal)230     public static <T extends FieldElement<T>> FieldMatrix<T> createFieldDiagonalMatrix(
231             final T[] diagonal) {
232         final FieldMatrix<T> m =
233                 createFieldMatrix(diagonal[0].getField(), diagonal.length, diagonal.length);
234         for (int i = 0; i < diagonal.length; ++i) {
235             m.setEntry(i, i, diagonal[i]);
236         }
237         return m;
238     }
239 
240     /**
241      * Creates a {@link RealVector} using the data from the input array.
242      *
243      * @param data the input data
244      * @return a data.length RealVector
245      * @throws NoDataException if {@code data} is empty.
246      * @throws NullArgumentException if {@code data} is {@code null}.
247      */
createRealVector(double[] data)248     public static RealVector createRealVector(double[] data)
249             throws NoDataException, NullArgumentException {
250         if (data == null) {
251             throw new NullArgumentException();
252         }
253         return new ArrayRealVector(data, true);
254     }
255 
256     /**
257      * Creates a {@link FieldVector} using the data from the input array.
258      *
259      * @param <T> the type of the field elements
260      * @param data the input data
261      * @return a data.length FieldVector
262      * @throws NoDataException if {@code data} is empty.
263      * @throws NullArgumentException if {@code data} is {@code null}.
264      * @throws ZeroException if {@code data} has 0 elements
265      */
createFieldVector(final T[] data)266     public static <T extends FieldElement<T>> FieldVector<T> createFieldVector(final T[] data)
267             throws NoDataException, NullArgumentException, ZeroException {
268         if (data == null) {
269             throw new NullArgumentException();
270         }
271         if (data.length == 0) {
272             throw new ZeroException(LocalizedFormats.VECTOR_MUST_HAVE_AT_LEAST_ONE_ELEMENT);
273         }
274         return new ArrayFieldVector<T>(data[0].getField(), data, true);
275     }
276 
277     /**
278      * Create a row {@link RealMatrix} using the data from the input array.
279      *
280      * @param rowData the input row data
281      * @return a 1 x rowData.length RealMatrix
282      * @throws NoDataException if {@code rowData} is empty.
283      * @throws NullArgumentException if {@code rowData} is {@code null}.
284      */
createRowRealMatrix(double[] rowData)285     public static RealMatrix createRowRealMatrix(double[] rowData)
286             throws NoDataException, NullArgumentException {
287         if (rowData == null) {
288             throw new NullArgumentException();
289         }
290         final int nCols = rowData.length;
291         final RealMatrix m = createRealMatrix(1, nCols);
292         for (int i = 0; i < nCols; ++i) {
293             m.setEntry(0, i, rowData[i]);
294         }
295         return m;
296     }
297 
298     /**
299      * Create a row {@link FieldMatrix} using the data from the input array.
300      *
301      * @param <T> the type of the field elements
302      * @param rowData the input row data
303      * @return a 1 x rowData.length FieldMatrix
304      * @throws NoDataException if {@code rowData} is empty.
305      * @throws NullArgumentException if {@code rowData} is {@code null}.
306      */
createRowFieldMatrix(final T[] rowData)307     public static <T extends FieldElement<T>> FieldMatrix<T> createRowFieldMatrix(final T[] rowData)
308             throws NoDataException, NullArgumentException {
309         if (rowData == null) {
310             throw new NullArgumentException();
311         }
312         final int nCols = rowData.length;
313         if (nCols == 0) {
314             throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_COLUMN);
315         }
316         final FieldMatrix<T> m = createFieldMatrix(rowData[0].getField(), 1, nCols);
317         for (int i = 0; i < nCols; ++i) {
318             m.setEntry(0, i, rowData[i]);
319         }
320         return m;
321     }
322 
323     /**
324      * Creates a column {@link RealMatrix} using the data from the input array.
325      *
326      * @param columnData the input column data
327      * @return a columnData x 1 RealMatrix
328      * @throws NoDataException if {@code columnData} is empty.
329      * @throws NullArgumentException if {@code columnData} is {@code null}.
330      */
createColumnRealMatrix(double[] columnData)331     public static RealMatrix createColumnRealMatrix(double[] columnData)
332             throws NoDataException, NullArgumentException {
333         if (columnData == null) {
334             throw new NullArgumentException();
335         }
336         final int nRows = columnData.length;
337         final RealMatrix m = createRealMatrix(nRows, 1);
338         for (int i = 0; i < nRows; ++i) {
339             m.setEntry(i, 0, columnData[i]);
340         }
341         return m;
342     }
343 
344     /**
345      * Creates a column {@link FieldMatrix} using the data from the input array.
346      *
347      * @param <T> the type of the field elements
348      * @param columnData the input column data
349      * @return a columnData x 1 FieldMatrix
350      * @throws NoDataException if {@code data} is empty.
351      * @throws NullArgumentException if {@code columnData} is {@code null}.
352      */
createColumnFieldMatrix( final T[] columnData)353     public static <T extends FieldElement<T>> FieldMatrix<T> createColumnFieldMatrix(
354             final T[] columnData) throws NoDataException, NullArgumentException {
355         if (columnData == null) {
356             throw new NullArgumentException();
357         }
358         final int nRows = columnData.length;
359         if (nRows == 0) {
360             throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_ROW);
361         }
362         final FieldMatrix<T> m = createFieldMatrix(columnData[0].getField(), nRows, 1);
363         for (int i = 0; i < nRows; ++i) {
364             m.setEntry(i, 0, columnData[i]);
365         }
366         return m;
367     }
368 
369     /**
370      * Checks whether a matrix is symmetric, within a given relative tolerance.
371      *
372      * @param matrix Matrix to check.
373      * @param relativeTolerance Tolerance of the symmetry check.
374      * @param raiseException If {@code true}, an exception will be raised if the matrix is not
375      *     symmetric.
376      * @return {@code true} if {@code matrix} is symmetric.
377      * @throws NonSquareMatrixException if the matrix is not square.
378      * @throws NonSymmetricMatrixException if the matrix is not symmetric.
379      */
isSymmetricInternal( RealMatrix matrix, double relativeTolerance, boolean raiseException)380     private static boolean isSymmetricInternal(
381             RealMatrix matrix, double relativeTolerance, boolean raiseException) {
382         final int rows = matrix.getRowDimension();
383         if (rows != matrix.getColumnDimension()) {
384             if (raiseException) {
385                 throw new NonSquareMatrixException(rows, matrix.getColumnDimension());
386             } else {
387                 return false;
388             }
389         }
390         for (int i = 0; i < rows; i++) {
391             for (int j = i + 1; j < rows; j++) {
392                 final double mij = matrix.getEntry(i, j);
393                 final double mji = matrix.getEntry(j, i);
394                 if (FastMath.abs(mij - mji)
395                         > FastMath.max(FastMath.abs(mij), FastMath.abs(mji)) * relativeTolerance) {
396                     if (raiseException) {
397                         throw new NonSymmetricMatrixException(i, j, relativeTolerance);
398                     } else {
399                         return false;
400                     }
401                 }
402             }
403         }
404         return true;
405     }
406 
407     /**
408      * Checks whether a matrix is symmetric.
409      *
410      * @param matrix Matrix to check.
411      * @param eps Relative tolerance.
412      * @throws NonSquareMatrixException if the matrix is not square.
413      * @throws NonSymmetricMatrixException if the matrix is not symmetric.
414      * @since 3.1
415      */
checkSymmetric(RealMatrix matrix, double eps)416     public static void checkSymmetric(RealMatrix matrix, double eps) {
417         isSymmetricInternal(matrix, eps, true);
418     }
419 
420     /**
421      * Checks whether a matrix is symmetric.
422      *
423      * @param matrix Matrix to check.
424      * @param eps Relative tolerance.
425      * @return {@code true} if {@code matrix} is symmetric.
426      * @since 3.1
427      */
isSymmetric(RealMatrix matrix, double eps)428     public static boolean isSymmetric(RealMatrix matrix, double eps) {
429         return isSymmetricInternal(matrix, eps, false);
430     }
431 
432     /**
433      * Check if matrix indices are valid.
434      *
435      * @param m Matrix.
436      * @param row Row index to check.
437      * @param column Column index to check.
438      * @throws OutOfRangeException if {@code row} or {@code column} is not a valid index.
439      */
checkMatrixIndex(final AnyMatrix m, final int row, final int column)440     public static void checkMatrixIndex(final AnyMatrix m, final int row, final int column)
441             throws OutOfRangeException {
442         checkRowIndex(m, row);
443         checkColumnIndex(m, column);
444     }
445 
446     /**
447      * Check if a row index is valid.
448      *
449      * @param m Matrix.
450      * @param row Row index to check.
451      * @throws OutOfRangeException if {@code row} is not a valid index.
452      */
checkRowIndex(final AnyMatrix m, final int row)453     public static void checkRowIndex(final AnyMatrix m, final int row) throws OutOfRangeException {
454         if (row < 0 || row >= m.getRowDimension()) {
455             throw new OutOfRangeException(
456                     LocalizedFormats.ROW_INDEX, row, 0, m.getRowDimension() - 1);
457         }
458     }
459 
460     /**
461      * Check if a column index is valid.
462      *
463      * @param m Matrix.
464      * @param column Column index to check.
465      * @throws OutOfRangeException if {@code column} is not a valid index.
466      */
checkColumnIndex(final AnyMatrix m, final int column)467     public static void checkColumnIndex(final AnyMatrix m, final int column)
468             throws OutOfRangeException {
469         if (column < 0 || column >= m.getColumnDimension()) {
470             throw new OutOfRangeException(
471                     LocalizedFormats.COLUMN_INDEX, column, 0, m.getColumnDimension() - 1);
472         }
473     }
474 
475     /**
476      * Check if submatrix ranges indices are valid. Rows and columns are indicated counting from 0
477      * to {@code n - 1}.
478      *
479      * @param m Matrix.
480      * @param startRow Initial row index.
481      * @param endRow Final row index.
482      * @param startColumn Initial column index.
483      * @param endColumn Final column index.
484      * @throws OutOfRangeException if the indices are invalid.
485      * @throws NumberIsTooSmallException if {@code endRow < startRow} or {@code endColumn <
486      *     startColumn}.
487      */
checkSubMatrixIndex( final AnyMatrix m, final int startRow, final int endRow, final int startColumn, final int endColumn)488     public static void checkSubMatrixIndex(
489             final AnyMatrix m,
490             final int startRow,
491             final int endRow,
492             final int startColumn,
493             final int endColumn)
494             throws NumberIsTooSmallException, OutOfRangeException {
495         checkRowIndex(m, startRow);
496         checkRowIndex(m, endRow);
497         if (endRow < startRow) {
498             throw new NumberIsTooSmallException(
499                     LocalizedFormats.INITIAL_ROW_AFTER_FINAL_ROW, endRow, startRow, false);
500         }
501 
502         checkColumnIndex(m, startColumn);
503         checkColumnIndex(m, endColumn);
504         if (endColumn < startColumn) {
505             throw new NumberIsTooSmallException(
506                     LocalizedFormats.INITIAL_COLUMN_AFTER_FINAL_COLUMN,
507                     endColumn,
508                     startColumn,
509                     false);
510         }
511     }
512 
513     /**
514      * Check if submatrix ranges indices are valid. Rows and columns are indicated counting from 0
515      * to n-1.
516      *
517      * @param m Matrix.
518      * @param selectedRows Array of row indices.
519      * @param selectedColumns Array of column indices.
520      * @throws NullArgumentException if {@code selectedRows} or {@code selectedColumns} are {@code
521      *     null}.
522      * @throws NoDataException if the row or column selections are empty (zero length).
523      * @throws OutOfRangeException if row or column selections are not valid.
524      */
checkSubMatrixIndex( final AnyMatrix m, final int[] selectedRows, final int[] selectedColumns)525     public static void checkSubMatrixIndex(
526             final AnyMatrix m, final int[] selectedRows, final int[] selectedColumns)
527             throws NoDataException, NullArgumentException, OutOfRangeException {
528         if (selectedRows == null) {
529             throw new NullArgumentException();
530         }
531         if (selectedColumns == null) {
532             throw new NullArgumentException();
533         }
534         if (selectedRows.length == 0) {
535             throw new NoDataException(LocalizedFormats.EMPTY_SELECTED_ROW_INDEX_ARRAY);
536         }
537         if (selectedColumns.length == 0) {
538             throw new NoDataException(LocalizedFormats.EMPTY_SELECTED_COLUMN_INDEX_ARRAY);
539         }
540 
541         for (final int row : selectedRows) {
542             checkRowIndex(m, row);
543         }
544         for (final int column : selectedColumns) {
545             checkColumnIndex(m, column);
546         }
547     }
548 
549     /**
550      * Check if matrices are addition compatible.
551      *
552      * @param left Left hand side matrix.
553      * @param right Right hand side matrix.
554      * @throws MatrixDimensionMismatchException if the matrices are not addition compatible.
555      */
checkAdditionCompatible(final AnyMatrix left, final AnyMatrix right)556     public static void checkAdditionCompatible(final AnyMatrix left, final AnyMatrix right)
557             throws MatrixDimensionMismatchException {
558         if ((left.getRowDimension() != right.getRowDimension())
559                 || (left.getColumnDimension() != right.getColumnDimension())) {
560             throw new MatrixDimensionMismatchException(
561                     left.getRowDimension(), left.getColumnDimension(),
562                     right.getRowDimension(), right.getColumnDimension());
563         }
564     }
565 
566     /**
567      * Check if matrices are subtraction compatible
568      *
569      * @param left Left hand side matrix.
570      * @param right Right hand side matrix.
571      * @throws MatrixDimensionMismatchException if the matrices are not addition compatible.
572      */
checkSubtractionCompatible(final AnyMatrix left, final AnyMatrix right)573     public static void checkSubtractionCompatible(final AnyMatrix left, final AnyMatrix right)
574             throws MatrixDimensionMismatchException {
575         if ((left.getRowDimension() != right.getRowDimension())
576                 || (left.getColumnDimension() != right.getColumnDimension())) {
577             throw new MatrixDimensionMismatchException(
578                     left.getRowDimension(), left.getColumnDimension(),
579                     right.getRowDimension(), right.getColumnDimension());
580         }
581     }
582 
583     /**
584      * Check if matrices are multiplication compatible
585      *
586      * @param left Left hand side matrix.
587      * @param right Right hand side matrix.
588      * @throws DimensionMismatchException if matrices are not multiplication compatible.
589      */
checkMultiplicationCompatible(final AnyMatrix left, final AnyMatrix right)590     public static void checkMultiplicationCompatible(final AnyMatrix left, final AnyMatrix right)
591             throws DimensionMismatchException {
592 
593         if (left.getColumnDimension() != right.getRowDimension()) {
594             throw new DimensionMismatchException(
595                     left.getColumnDimension(), right.getRowDimension());
596         }
597     }
598 
599     /**
600      * Convert a {@link FieldMatrix}/{@link Fraction} matrix to a {@link RealMatrix}.
601      *
602      * @param m Matrix to convert.
603      * @return the converted matrix.
604      */
fractionMatrixToRealMatrix(final FieldMatrix<Fraction> m)605     public static Array2DRowRealMatrix fractionMatrixToRealMatrix(final FieldMatrix<Fraction> m) {
606         final FractionMatrixConverter converter = new FractionMatrixConverter();
607         m.walkInOptimizedOrder(converter);
608         return converter.getConvertedMatrix();
609     }
610 
611     /** Converter for {@link FieldMatrix}/{@link Fraction}. */
612     private static class FractionMatrixConverter
613             extends DefaultFieldMatrixPreservingVisitor<Fraction> {
614         /** Converted array. */
615         private double[][] data;
616 
617         /** Simple constructor. */
FractionMatrixConverter()618         FractionMatrixConverter() {
619             super(Fraction.ZERO);
620         }
621 
622         /** {@inheritDoc} */
623         @Override
start( int rows, int columns, int startRow, int endRow, int startColumn, int endColumn)624         public void start(
625                 int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
626             data = new double[rows][columns];
627         }
628 
629         /** {@inheritDoc} */
630         @Override
visit(int row, int column, Fraction value)631         public void visit(int row, int column, Fraction value) {
632             data[row][column] = value.doubleValue();
633         }
634 
635         /**
636          * Get the converted matrix.
637          *
638          * @return the converted matrix.
639          */
getConvertedMatrix()640         Array2DRowRealMatrix getConvertedMatrix() {
641             return new Array2DRowRealMatrix(data, false);
642         }
643     }
644 
645     /**
646      * Convert a {@link FieldMatrix}/{@link BigFraction} matrix to a {@link RealMatrix}.
647      *
648      * @param m Matrix to convert.
649      * @return the converted matrix.
650      */
bigFractionMatrixToRealMatrix( final FieldMatrix<BigFraction> m)651     public static Array2DRowRealMatrix bigFractionMatrixToRealMatrix(
652             final FieldMatrix<BigFraction> m) {
653         final BigFractionMatrixConverter converter = new BigFractionMatrixConverter();
654         m.walkInOptimizedOrder(converter);
655         return converter.getConvertedMatrix();
656     }
657 
658     /** Converter for {@link FieldMatrix}/{@link BigFraction}. */
659     private static class BigFractionMatrixConverter
660             extends DefaultFieldMatrixPreservingVisitor<BigFraction> {
661         /** Converted array. */
662         private double[][] data;
663 
664         /** Simple constructor. */
BigFractionMatrixConverter()665         BigFractionMatrixConverter() {
666             super(BigFraction.ZERO);
667         }
668 
669         /** {@inheritDoc} */
670         @Override
start( int rows, int columns, int startRow, int endRow, int startColumn, int endColumn)671         public void start(
672                 int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
673             data = new double[rows][columns];
674         }
675 
676         /** {@inheritDoc} */
677         @Override
visit(int row, int column, BigFraction value)678         public void visit(int row, int column, BigFraction value) {
679             data[row][column] = value.doubleValue();
680         }
681 
682         /**
683          * Get the converted matrix.
684          *
685          * @return the converted matrix.
686          */
getConvertedMatrix()687         Array2DRowRealMatrix getConvertedMatrix() {
688             return new Array2DRowRealMatrix(data, false);
689         }
690     }
691 
692     /**
693      * Serialize a {@link RealVector}.
694      *
695      * <p>This method is intended to be called from within a private <code>writeObject</code> method
696      * (after a call to <code>oos.defaultWriteObject()</code>) in a class that has a {@link
697      * RealVector} field, which should be declared <code>transient</code>. This way, the default
698      * handling does not serialize the vector (the {@link RealVector} interface is not serializable
699      * by default) but this method does serialize it specifically.
700      *
701      * <p>The following example shows how a simple class with a name and a real vector should be
702      * written:
703      *
704      * <pre><code>
705      * public class NamedVector implements Serializable {
706      *
707      *     private final String name;
708      *     private final transient RealVector coefficients;
709      *
710      *     // omitted constructors, getters ...
711      *
712      *     private void writeObject(ObjectOutputStream oos) throws IOException {
713      *         oos.defaultWriteObject();  // takes care of name field
714      *         MatrixUtils.serializeRealVector(coefficients, oos);
715      *     }
716      *
717      *     private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
718      *         ois.defaultReadObject();  // takes care of name field
719      *         MatrixUtils.deserializeRealVector(this, "coefficients", ois);
720      *     }
721      *
722      * }
723      * </code></pre>
724      *
725      * @param vector real vector to serialize
726      * @param oos stream where the real vector should be written
727      * @exception IOException if object cannot be written to stream
728      * @see #deserializeRealVector(Object, String, ObjectInputStream)
729      */
serializeRealVector(final RealVector vector, final ObjectOutputStream oos)730     public static void serializeRealVector(final RealVector vector, final ObjectOutputStream oos)
731             throws IOException {
732         final int n = vector.getDimension();
733         oos.writeInt(n);
734         for (int i = 0; i < n; ++i) {
735             oos.writeDouble(vector.getEntry(i));
736         }
737     }
738 
739     /**
740      * Deserialize a {@link RealVector} field in a class.
741      *
742      * <p>This method is intended to be called from within a private <code>readObject</code> method
743      * (after a call to <code>ois.defaultReadObject()</code>) in a class that has a {@link
744      * RealVector} field, which should be declared <code>transient</code>. This way, the default
745      * handling does not deserialize the vector (the {@link RealVector} interface is not
746      * serializable by default) but this method does deserialize it specifically.
747      *
748      * @param instance instance in which the field must be set up
749      * @param fieldName name of the field within the class (may be private and final)
750      * @param ois stream from which the real vector should be read
751      * @exception ClassNotFoundException if a class in the stream cannot be found
752      * @exception IOException if object cannot be read from the stream
753      * @see #serializeRealVector(RealVector, ObjectOutputStream)
754      */
deserializeRealVector( final Object instance, final String fieldName, final ObjectInputStream ois)755     public static void deserializeRealVector(
756             final Object instance, final String fieldName, final ObjectInputStream ois)
757             throws ClassNotFoundException, IOException {
758         try {
759 
760             // read the vector data
761             final int n = ois.readInt();
762             final double[] data = new double[n];
763             for (int i = 0; i < n; ++i) {
764                 data[i] = ois.readDouble();
765             }
766 
767             // create the instance
768             final RealVector vector = new ArrayRealVector(data, false);
769 
770             // set up the field
771             final java.lang.reflect.Field f = instance.getClass().getDeclaredField(fieldName);
772             f.setAccessible(true);
773             f.set(instance, vector);
774 
775         } catch (NoSuchFieldException nsfe) {
776             IOException ioe = new IOException();
777             ioe.initCause(nsfe);
778             throw ioe;
779         } catch (IllegalAccessException iae) {
780             IOException ioe = new IOException();
781             ioe.initCause(iae);
782             throw ioe;
783         }
784     }
785 
786     /**
787      * Serialize a {@link RealMatrix}.
788      *
789      * <p>This method is intended to be called from within a private <code>writeObject</code> method
790      * (after a call to <code>oos.defaultWriteObject()</code>) in a class that has a {@link
791      * RealMatrix} field, which should be declared <code>transient</code>. This way, the default
792      * handling does not serialize the matrix (the {@link RealMatrix} interface is not serializable
793      * by default) but this method does serialize it specifically.
794      *
795      * <p>The following example shows how a simple class with a name and a real matrix should be
796      * written:
797      *
798      * <pre><code>
799      * public class NamedMatrix implements Serializable {
800      *
801      *     private final String name;
802      *     private final transient RealMatrix coefficients;
803      *
804      *     // omitted constructors, getters ...
805      *
806      *     private void writeObject(ObjectOutputStream oos) throws IOException {
807      *         oos.defaultWriteObject();  // takes care of name field
808      *         MatrixUtils.serializeRealMatrix(coefficients, oos);
809      *     }
810      *
811      *     private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
812      *         ois.defaultReadObject();  // takes care of name field
813      *         MatrixUtils.deserializeRealMatrix(this, "coefficients", ois);
814      *     }
815      *
816      * }
817      * </code></pre>
818      *
819      * @param matrix real matrix to serialize
820      * @param oos stream where the real matrix should be written
821      * @exception IOException if object cannot be written to stream
822      * @see #deserializeRealMatrix(Object, String, ObjectInputStream)
823      */
serializeRealMatrix(final RealMatrix matrix, final ObjectOutputStream oos)824     public static void serializeRealMatrix(final RealMatrix matrix, final ObjectOutputStream oos)
825             throws IOException {
826         final int n = matrix.getRowDimension();
827         final int m = matrix.getColumnDimension();
828         oos.writeInt(n);
829         oos.writeInt(m);
830         for (int i = 0; i < n; ++i) {
831             for (int j = 0; j < m; ++j) {
832                 oos.writeDouble(matrix.getEntry(i, j));
833             }
834         }
835     }
836 
837     /**
838      * Deserialize a {@link RealMatrix} field in a class.
839      *
840      * <p>This method is intended to be called from within a private <code>readObject</code> method
841      * (after a call to <code>ois.defaultReadObject()</code>) in a class that has a {@link
842      * RealMatrix} field, which should be declared <code>transient</code>. This way, the default
843      * handling does not deserialize the matrix (the {@link RealMatrix} interface is not
844      * serializable by default) but this method does deserialize it specifically.
845      *
846      * @param instance instance in which the field must be set up
847      * @param fieldName name of the field within the class (may be private and final)
848      * @param ois stream from which the real matrix should be read
849      * @exception ClassNotFoundException if a class in the stream cannot be found
850      * @exception IOException if object cannot be read from the stream
851      * @see #serializeRealMatrix(RealMatrix, ObjectOutputStream)
852      */
deserializeRealMatrix( final Object instance, final String fieldName, final ObjectInputStream ois)853     public static void deserializeRealMatrix(
854             final Object instance, final String fieldName, final ObjectInputStream ois)
855             throws ClassNotFoundException, IOException {
856         try {
857 
858             // read the matrix data
859             final int n = ois.readInt();
860             final int m = ois.readInt();
861             final double[][] data = new double[n][m];
862             for (int i = 0; i < n; ++i) {
863                 final double[] dataI = data[i];
864                 for (int j = 0; j < m; ++j) {
865                     dataI[j] = ois.readDouble();
866                 }
867             }
868 
869             // create the instance
870             final RealMatrix matrix = new Array2DRowRealMatrix(data, false);
871 
872             // set up the field
873             final java.lang.reflect.Field f = instance.getClass().getDeclaredField(fieldName);
874             f.setAccessible(true);
875             f.set(instance, matrix);
876 
877         } catch (NoSuchFieldException nsfe) {
878             IOException ioe = new IOException();
879             ioe.initCause(nsfe);
880             throw ioe;
881         } catch (IllegalAccessException iae) {
882             IOException ioe = new IOException();
883             ioe.initCause(iae);
884             throw ioe;
885         }
886     }
887 
888     /**
889      * Solve a system of composed of a Lower Triangular Matrix {@link RealMatrix}.
890      *
891      * <p>This method is called to solve systems of equations which are of the lower triangular
892      * form. The matrix {@link RealMatrix} is assumed, though not checked, to be in lower triangular
893      * form. The vector {@link RealVector} is overwritten with the solution. The matrix is checked
894      * that it is square and its dimensions match the length of the vector.
895      *
896      * @param rm RealMatrix which is lower triangular
897      * @param b RealVector this is overwritten
898      * @throws DimensionMismatchException if the matrix and vector are not conformable
899      * @throws NonSquareMatrixException if the matrix {@code rm} is not square
900      * @throws MathArithmeticException if the absolute value of one of the diagonal coefficient of
901      *     {@code rm} is lower than {@link Precision#SAFE_MIN}
902      */
solveLowerTriangularSystem(RealMatrix rm, RealVector b)903     public static void solveLowerTriangularSystem(RealMatrix rm, RealVector b)
904             throws DimensionMismatchException, MathArithmeticException, NonSquareMatrixException {
905         if ((rm == null) || (b == null) || (rm.getRowDimension() != b.getDimension())) {
906             throw new DimensionMismatchException(
907                     (rm == null) ? 0 : rm.getRowDimension(), (b == null) ? 0 : b.getDimension());
908         }
909         if (rm.getColumnDimension() != rm.getRowDimension()) {
910             throw new NonSquareMatrixException(rm.getRowDimension(), rm.getColumnDimension());
911         }
912         int rows = rm.getRowDimension();
913         for (int i = 0; i < rows; i++) {
914             double diag = rm.getEntry(i, i);
915             if (FastMath.abs(diag) < Precision.SAFE_MIN) {
916                 throw new MathArithmeticException(LocalizedFormats.ZERO_DENOMINATOR);
917             }
918             double bi = b.getEntry(i) / diag;
919             b.setEntry(i, bi);
920             for (int j = i + 1; j < rows; j++) {
921                 b.setEntry(j, b.getEntry(j) - bi * rm.getEntry(j, i));
922             }
923         }
924     }
925 
926     /**
927      * Solver a system composed of an Upper Triangular Matrix {@link RealMatrix}.
928      *
929      * <p>This method is called to solve systems of equations which are of the lower triangular
930      * form. The matrix {@link RealMatrix} is assumed, though not checked, to be in upper triangular
931      * form. The vector {@link RealVector} is overwritten with the solution. The matrix is checked
932      * that it is square and its dimensions match the length of the vector.
933      *
934      * @param rm RealMatrix which is upper triangular
935      * @param b RealVector this is overwritten
936      * @throws DimensionMismatchException if the matrix and vector are not conformable
937      * @throws NonSquareMatrixException if the matrix {@code rm} is not square
938      * @throws MathArithmeticException if the absolute value of one of the diagonal coefficient of
939      *     {@code rm} is lower than {@link Precision#SAFE_MIN}
940      */
solveUpperTriangularSystem(RealMatrix rm, RealVector b)941     public static void solveUpperTriangularSystem(RealMatrix rm, RealVector b)
942             throws DimensionMismatchException, MathArithmeticException, NonSquareMatrixException {
943         if ((rm == null) || (b == null) || (rm.getRowDimension() != b.getDimension())) {
944             throw new DimensionMismatchException(
945                     (rm == null) ? 0 : rm.getRowDimension(), (b == null) ? 0 : b.getDimension());
946         }
947         if (rm.getColumnDimension() != rm.getRowDimension()) {
948             throw new NonSquareMatrixException(rm.getRowDimension(), rm.getColumnDimension());
949         }
950         int rows = rm.getRowDimension();
951         for (int i = rows - 1; i > -1; i--) {
952             double diag = rm.getEntry(i, i);
953             if (FastMath.abs(diag) < Precision.SAFE_MIN) {
954                 throw new MathArithmeticException(LocalizedFormats.ZERO_DENOMINATOR);
955             }
956             double bi = b.getEntry(i) / diag;
957             b.setEntry(i, bi);
958             for (int j = i - 1; j > -1; j--) {
959                 b.setEntry(j, b.getEntry(j) - bi * rm.getEntry(j, i));
960             }
961         }
962     }
963 
964     /**
965      * Computes the inverse of the given matrix by splitting it into 4 sub-matrices.
966      *
967      * @param m Matrix whose inverse must be computed.
968      * @param splitIndex Index that determines the "split" line and column. The element
969      *     corresponding to this index will part of the upper-left sub-matrix.
970      * @return the inverse of {@code m}.
971      * @throws NonSquareMatrixException if {@code m} is not square.
972      */
blockInverse(RealMatrix m, int splitIndex)973     public static RealMatrix blockInverse(RealMatrix m, int splitIndex) {
974         final int n = m.getRowDimension();
975         if (m.getColumnDimension() != n) {
976             throw new NonSquareMatrixException(m.getRowDimension(), m.getColumnDimension());
977         }
978 
979         final int splitIndex1 = splitIndex + 1;
980 
981         final RealMatrix a = m.getSubMatrix(0, splitIndex, 0, splitIndex);
982         final RealMatrix b = m.getSubMatrix(0, splitIndex, splitIndex1, n - 1);
983         final RealMatrix c = m.getSubMatrix(splitIndex1, n - 1, 0, splitIndex);
984         final RealMatrix d = m.getSubMatrix(splitIndex1, n - 1, splitIndex1, n - 1);
985 
986         final SingularValueDecomposition aDec = new SingularValueDecomposition(a);
987         final DecompositionSolver aSolver = aDec.getSolver();
988         if (!aSolver.isNonSingular()) {
989             throw new SingularMatrixException();
990         }
991         final RealMatrix aInv = aSolver.getInverse();
992 
993         final SingularValueDecomposition dDec = new SingularValueDecomposition(d);
994         final DecompositionSolver dSolver = dDec.getSolver();
995         if (!dSolver.isNonSingular()) {
996             throw new SingularMatrixException();
997         }
998         final RealMatrix dInv = dSolver.getInverse();
999 
1000         final RealMatrix tmp1 = a.subtract(b.multiply(dInv).multiply(c));
1001         final SingularValueDecomposition tmp1Dec = new SingularValueDecomposition(tmp1);
1002         final DecompositionSolver tmp1Solver = tmp1Dec.getSolver();
1003         if (!tmp1Solver.isNonSingular()) {
1004             throw new SingularMatrixException();
1005         }
1006         final RealMatrix result00 = tmp1Solver.getInverse();
1007 
1008         final RealMatrix tmp2 = d.subtract(c.multiply(aInv).multiply(b));
1009         final SingularValueDecomposition tmp2Dec = new SingularValueDecomposition(tmp2);
1010         final DecompositionSolver tmp2Solver = tmp2Dec.getSolver();
1011         if (!tmp2Solver.isNonSingular()) {
1012             throw new SingularMatrixException();
1013         }
1014         final RealMatrix result11 = tmp2Solver.getInverse();
1015 
1016         final RealMatrix result01 = aInv.multiply(b).multiply(result11).scalarMultiply(-1);
1017         final RealMatrix result10 = dInv.multiply(c).multiply(result00).scalarMultiply(-1);
1018 
1019         final RealMatrix result = new Array2DRowRealMatrix(n, n);
1020         result.setSubMatrix(result00.getData(), 0, 0);
1021         result.setSubMatrix(result01.getData(), 0, splitIndex1);
1022         result.setSubMatrix(result10.getData(), splitIndex1, 0);
1023         result.setSubMatrix(result11.getData(), splitIndex1, splitIndex1);
1024 
1025         return result;
1026     }
1027 
1028     /**
1029      * Computes the inverse of the given matrix.
1030      *
1031      * <p>By default, the inverse of the matrix is computed using the QR-decomposition, unless a
1032      * more efficient method can be determined for the input matrix.
1033      *
1034      * <p>Note: this method will use a singularity threshold of 0, use {@link #inverse(RealMatrix,
1035      * double)} if a different threshold is needed.
1036      *
1037      * @param matrix Matrix whose inverse shall be computed
1038      * @return the inverse of {@code matrix}
1039      * @throws NullArgumentException if {@code matrix} is {@code null}
1040      * @throws SingularMatrixException if m is singular
1041      * @throws NonSquareMatrixException if matrix is not square
1042      * @since 3.3
1043      */
inverse(RealMatrix matrix)1044     public static RealMatrix inverse(RealMatrix matrix)
1045             throws NullArgumentException, SingularMatrixException, NonSquareMatrixException {
1046         return inverse(matrix, 0);
1047     }
1048 
1049     /**
1050      * Computes the inverse of the given matrix.
1051      *
1052      * <p>By default, the inverse of the matrix is computed using the QR-decomposition, unless a
1053      * more efficient method can be determined for the input matrix.
1054      *
1055      * @param matrix Matrix whose inverse shall be computed
1056      * @param threshold Singularity threshold
1057      * @return the inverse of {@code m}
1058      * @throws NullArgumentException if {@code matrix} is {@code null}
1059      * @throws SingularMatrixException if matrix is singular
1060      * @throws NonSquareMatrixException if matrix is not square
1061      * @since 3.3
1062      */
inverse(RealMatrix matrix, double threshold)1063     public static RealMatrix inverse(RealMatrix matrix, double threshold)
1064             throws NullArgumentException, SingularMatrixException, NonSquareMatrixException {
1065 
1066         MathUtils.checkNotNull(matrix);
1067 
1068         if (!matrix.isSquare()) {
1069             throw new NonSquareMatrixException(
1070                     matrix.getRowDimension(), matrix.getColumnDimension());
1071         }
1072 
1073         if (matrix instanceof DiagonalMatrix) {
1074             return ((DiagonalMatrix) matrix).inverse(threshold);
1075         } else {
1076             QRDecomposition decomposition = new QRDecomposition(matrix, threshold);
1077             return decomposition.getSolver().getInverse();
1078         }
1079     }
1080 }
1081