1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.commons.math3.linear; 19 20 import org.apache.commons.math3.Field; 21 import org.apache.commons.math3.FieldElement; 22 import org.apache.commons.math3.exception.DimensionMismatchException; 23 import org.apache.commons.math3.exception.MathArithmeticException; 24 import org.apache.commons.math3.exception.NoDataException; 25 import org.apache.commons.math3.exception.NullArgumentException; 26 import org.apache.commons.math3.exception.NumberIsTooSmallException; 27 import org.apache.commons.math3.exception.OutOfRangeException; 28 import org.apache.commons.math3.exception.ZeroException; 29 import org.apache.commons.math3.exception.util.LocalizedFormats; 30 import org.apache.commons.math3.fraction.BigFraction; 31 import org.apache.commons.math3.fraction.Fraction; 32 import org.apache.commons.math3.util.FastMath; 33 import org.apache.commons.math3.util.MathArrays; 34 import org.apache.commons.math3.util.MathUtils; 35 import org.apache.commons.math3.util.Precision; 36 37 import java.io.IOException; 38 import java.io.ObjectInputStream; 39 import java.io.ObjectOutputStream; 40 import java.util.Arrays; 41 42 /** A collection of static methods that operate on or return matrices. */ 43 public class MatrixUtils { 44 45 /** 46 * The default format for {@link RealMatrix} objects. 47 * 48 * @since 3.1 49 */ 50 public static final RealMatrixFormat DEFAULT_FORMAT = RealMatrixFormat.getInstance(); 51 52 /** 53 * A format for {@link RealMatrix} objects compatible with octave. 54 * 55 * @since 3.1 56 */ 57 public static final RealMatrixFormat OCTAVE_FORMAT = 58 new RealMatrixFormat("[", "]", "", "", "; ", ", "); 59 60 /** Private constructor. */ MatrixUtils()61 private MatrixUtils() { 62 super(); 63 } 64 65 /** 66 * Returns a {@link RealMatrix} with specified dimensions. 67 * 68 * <p>The type of matrix returned depends on the dimension. Below 2<sup>12</sup> elements (i.e. 69 * 4096 elements or 64×64 for a square matrix) which can be stored in a 32kB array, a 70 * {@link Array2DRowRealMatrix} instance is built. Above this threshold a {@link 71 * BlockRealMatrix} instance is built. 72 * 73 * <p>The matrix elements are all set to 0.0. 74 * 75 * @param rows number of rows of the matrix 76 * @param columns number of columns of the matrix 77 * @return RealMatrix with specified dimensions 78 * @see #createRealMatrix(double[][]) 79 */ createRealMatrix(final int rows, final int columns)80 public static RealMatrix createRealMatrix(final int rows, final int columns) { 81 return (rows * columns <= 4096) 82 ? new Array2DRowRealMatrix(rows, columns) 83 : new BlockRealMatrix(rows, columns); 84 } 85 86 /** 87 * Returns a {@link FieldMatrix} with specified dimensions. 88 * 89 * <p>The type of matrix returned depends on the dimension. Below 2<sup>12</sup> elements (i.e. 90 * 4096 elements or 64×64 for a square matrix), a {@link FieldMatrix} instance is built. 91 * Above this threshold a {@link BlockFieldMatrix} instance is built. 92 * 93 * <p>The matrix elements are all set to field.getZero(). 94 * 95 * @param <T> the type of the field elements 96 * @param field field to which the matrix elements belong 97 * @param rows number of rows of the matrix 98 * @param columns number of columns of the matrix 99 * @return FieldMatrix with specified dimensions 100 * @see #createFieldMatrix(FieldElement[][]) 101 * @since 2.0 102 */ createFieldMatrix( final Field<T> field, final int rows, final int columns)103 public static <T extends FieldElement<T>> FieldMatrix<T> createFieldMatrix( 104 final Field<T> field, final int rows, final int columns) { 105 return (rows * columns <= 4096) 106 ? new Array2DRowFieldMatrix<T>(field, rows, columns) 107 : new BlockFieldMatrix<T>(field, rows, columns); 108 } 109 110 /** 111 * Returns a {@link RealMatrix} whose entries are the the values in the the input array. 112 * 113 * <p>The type of matrix returned depends on the dimension. Below 2<sup>12</sup> elements (i.e. 114 * 4096 elements or 64×64 for a square matrix) which can be stored in a 32kB array, a 115 * {@link Array2DRowRealMatrix} instance is built. Above this threshold a {@link 116 * BlockRealMatrix} instance is built. 117 * 118 * <p>The input array is copied, not referenced. 119 * 120 * @param data input array 121 * @return RealMatrix containing the values of the array 122 * @throws org.apache.commons.math3.exception.DimensionMismatchException if {@code data} is not 123 * rectangular (not all rows have the same length). 124 * @throws NoDataException if a row or column is empty. 125 * @throws NullArgumentException if either {@code data} or {@code data[0]} is {@code null}. 126 * @throws DimensionMismatchException if {@code data} is not rectangular. 127 * @see #createRealMatrix(int, int) 128 */ createRealMatrix(double[][] data)129 public static RealMatrix createRealMatrix(double[][] data) 130 throws NullArgumentException, DimensionMismatchException, NoDataException { 131 if (data == null || data[0] == null) { 132 throw new NullArgumentException(); 133 } 134 return (data.length * data[0].length <= 4096) 135 ? new Array2DRowRealMatrix(data) 136 : new BlockRealMatrix(data); 137 } 138 139 /** 140 * Returns a {@link FieldMatrix} whose entries are the the values in the the input array. 141 * 142 * <p>The type of matrix returned depends on the dimension. Below 2<sup>12</sup> elements (i.e. 143 * 4096 elements or 64×64 for a square matrix), a {@link FieldMatrix} instance is built. 144 * Above this threshold a {@link BlockFieldMatrix} instance is built. 145 * 146 * <p>The input array is copied, not referenced. 147 * 148 * @param <T> the type of the field elements 149 * @param data input array 150 * @return a matrix containing the values of the array. 151 * @throws org.apache.commons.math3.exception.DimensionMismatchException if {@code data} is not 152 * rectangular (not all rows have the same length). 153 * @throws NoDataException if a row or column is empty. 154 * @throws NullArgumentException if either {@code data} or {@code data[0]} is {@code null}. 155 * @see #createFieldMatrix(Field, int, int) 156 * @since 2.0 157 */ createFieldMatrix(T[][] data)158 public static <T extends FieldElement<T>> FieldMatrix<T> createFieldMatrix(T[][] data) 159 throws DimensionMismatchException, NoDataException, NullArgumentException { 160 if (data == null || data[0] == null) { 161 throw new NullArgumentException(); 162 } 163 return (data.length * data[0].length <= 4096) 164 ? new Array2DRowFieldMatrix<T>(data) 165 : new BlockFieldMatrix<T>(data); 166 } 167 168 /** 169 * Returns <code>dimension x dimension</code> identity matrix. 170 * 171 * @param dimension dimension of identity matrix to generate 172 * @return identity matrix 173 * @throws IllegalArgumentException if dimension is not positive 174 * @since 1.1 175 */ createRealIdentityMatrix(int dimension)176 public static RealMatrix createRealIdentityMatrix(int dimension) { 177 final RealMatrix m = createRealMatrix(dimension, dimension); 178 for (int i = 0; i < dimension; ++i) { 179 m.setEntry(i, i, 1.0); 180 } 181 return m; 182 } 183 184 /** 185 * Returns <code>dimension x dimension</code> identity matrix. 186 * 187 * @param <T> the type of the field elements 188 * @param field field to which the elements belong 189 * @param dimension dimension of identity matrix to generate 190 * @return identity matrix 191 * @throws IllegalArgumentException if dimension is not positive 192 * @since 2.0 193 */ createFieldIdentityMatrix( final Field<T> field, final int dimension)194 public static <T extends FieldElement<T>> FieldMatrix<T> createFieldIdentityMatrix( 195 final Field<T> field, final int dimension) { 196 final T zero = field.getZero(); 197 final T one = field.getOne(); 198 final T[][] d = MathArrays.buildArray(field, dimension, dimension); 199 for (int row = 0; row < dimension; row++) { 200 final T[] dRow = d[row]; 201 Arrays.fill(dRow, zero); 202 dRow[row] = one; 203 } 204 return new Array2DRowFieldMatrix<T>(field, d, false); 205 } 206 207 /** 208 * Returns a diagonal matrix with specified elements. 209 * 210 * @param diagonal diagonal elements of the matrix (the array elements will be copied) 211 * @return diagonal matrix 212 * @since 2.0 213 */ createRealDiagonalMatrix(final double[] diagonal)214 public static RealMatrix createRealDiagonalMatrix(final double[] diagonal) { 215 final RealMatrix m = createRealMatrix(diagonal.length, diagonal.length); 216 for (int i = 0; i < diagonal.length; ++i) { 217 m.setEntry(i, i, diagonal[i]); 218 } 219 return m; 220 } 221 222 /** 223 * Returns a diagonal matrix with specified elements. 224 * 225 * @param <T> the type of the field elements 226 * @param diagonal diagonal elements of the matrix (the array elements will be copied) 227 * @return diagonal matrix 228 * @since 2.0 229 */ createFieldDiagonalMatrix( final T[] diagonal)230 public static <T extends FieldElement<T>> FieldMatrix<T> createFieldDiagonalMatrix( 231 final T[] diagonal) { 232 final FieldMatrix<T> m = 233 createFieldMatrix(diagonal[0].getField(), diagonal.length, diagonal.length); 234 for (int i = 0; i < diagonal.length; ++i) { 235 m.setEntry(i, i, diagonal[i]); 236 } 237 return m; 238 } 239 240 /** 241 * Creates a {@link RealVector} using the data from the input array. 242 * 243 * @param data the input data 244 * @return a data.length RealVector 245 * @throws NoDataException if {@code data} is empty. 246 * @throws NullArgumentException if {@code data} is {@code null}. 247 */ createRealVector(double[] data)248 public static RealVector createRealVector(double[] data) 249 throws NoDataException, NullArgumentException { 250 if (data == null) { 251 throw new NullArgumentException(); 252 } 253 return new ArrayRealVector(data, true); 254 } 255 256 /** 257 * Creates a {@link FieldVector} using the data from the input array. 258 * 259 * @param <T> the type of the field elements 260 * @param data the input data 261 * @return a data.length FieldVector 262 * @throws NoDataException if {@code data} is empty. 263 * @throws NullArgumentException if {@code data} is {@code null}. 264 * @throws ZeroException if {@code data} has 0 elements 265 */ createFieldVector(final T[] data)266 public static <T extends FieldElement<T>> FieldVector<T> createFieldVector(final T[] data) 267 throws NoDataException, NullArgumentException, ZeroException { 268 if (data == null) { 269 throw new NullArgumentException(); 270 } 271 if (data.length == 0) { 272 throw new ZeroException(LocalizedFormats.VECTOR_MUST_HAVE_AT_LEAST_ONE_ELEMENT); 273 } 274 return new ArrayFieldVector<T>(data[0].getField(), data, true); 275 } 276 277 /** 278 * Create a row {@link RealMatrix} using the data from the input array. 279 * 280 * @param rowData the input row data 281 * @return a 1 x rowData.length RealMatrix 282 * @throws NoDataException if {@code rowData} is empty. 283 * @throws NullArgumentException if {@code rowData} is {@code null}. 284 */ createRowRealMatrix(double[] rowData)285 public static RealMatrix createRowRealMatrix(double[] rowData) 286 throws NoDataException, NullArgumentException { 287 if (rowData == null) { 288 throw new NullArgumentException(); 289 } 290 final int nCols = rowData.length; 291 final RealMatrix m = createRealMatrix(1, nCols); 292 for (int i = 0; i < nCols; ++i) { 293 m.setEntry(0, i, rowData[i]); 294 } 295 return m; 296 } 297 298 /** 299 * Create a row {@link FieldMatrix} using the data from the input array. 300 * 301 * @param <T> the type of the field elements 302 * @param rowData the input row data 303 * @return a 1 x rowData.length FieldMatrix 304 * @throws NoDataException if {@code rowData} is empty. 305 * @throws NullArgumentException if {@code rowData} is {@code null}. 306 */ createRowFieldMatrix(final T[] rowData)307 public static <T extends FieldElement<T>> FieldMatrix<T> createRowFieldMatrix(final T[] rowData) 308 throws NoDataException, NullArgumentException { 309 if (rowData == null) { 310 throw new NullArgumentException(); 311 } 312 final int nCols = rowData.length; 313 if (nCols == 0) { 314 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_COLUMN); 315 } 316 final FieldMatrix<T> m = createFieldMatrix(rowData[0].getField(), 1, nCols); 317 for (int i = 0; i < nCols; ++i) { 318 m.setEntry(0, i, rowData[i]); 319 } 320 return m; 321 } 322 323 /** 324 * Creates a column {@link RealMatrix} using the data from the input array. 325 * 326 * @param columnData the input column data 327 * @return a columnData x 1 RealMatrix 328 * @throws NoDataException if {@code columnData} is empty. 329 * @throws NullArgumentException if {@code columnData} is {@code null}. 330 */ createColumnRealMatrix(double[] columnData)331 public static RealMatrix createColumnRealMatrix(double[] columnData) 332 throws NoDataException, NullArgumentException { 333 if (columnData == null) { 334 throw new NullArgumentException(); 335 } 336 final int nRows = columnData.length; 337 final RealMatrix m = createRealMatrix(nRows, 1); 338 for (int i = 0; i < nRows; ++i) { 339 m.setEntry(i, 0, columnData[i]); 340 } 341 return m; 342 } 343 344 /** 345 * Creates a column {@link FieldMatrix} using the data from the input array. 346 * 347 * @param <T> the type of the field elements 348 * @param columnData the input column data 349 * @return a columnData x 1 FieldMatrix 350 * @throws NoDataException if {@code data} is empty. 351 * @throws NullArgumentException if {@code columnData} is {@code null}. 352 */ createColumnFieldMatrix( final T[] columnData)353 public static <T extends FieldElement<T>> FieldMatrix<T> createColumnFieldMatrix( 354 final T[] columnData) throws NoDataException, NullArgumentException { 355 if (columnData == null) { 356 throw new NullArgumentException(); 357 } 358 final int nRows = columnData.length; 359 if (nRows == 0) { 360 throw new NoDataException(LocalizedFormats.AT_LEAST_ONE_ROW); 361 } 362 final FieldMatrix<T> m = createFieldMatrix(columnData[0].getField(), nRows, 1); 363 for (int i = 0; i < nRows; ++i) { 364 m.setEntry(i, 0, columnData[i]); 365 } 366 return m; 367 } 368 369 /** 370 * Checks whether a matrix is symmetric, within a given relative tolerance. 371 * 372 * @param matrix Matrix to check. 373 * @param relativeTolerance Tolerance of the symmetry check. 374 * @param raiseException If {@code true}, an exception will be raised if the matrix is not 375 * symmetric. 376 * @return {@code true} if {@code matrix} is symmetric. 377 * @throws NonSquareMatrixException if the matrix is not square. 378 * @throws NonSymmetricMatrixException if the matrix is not symmetric. 379 */ isSymmetricInternal( RealMatrix matrix, double relativeTolerance, boolean raiseException)380 private static boolean isSymmetricInternal( 381 RealMatrix matrix, double relativeTolerance, boolean raiseException) { 382 final int rows = matrix.getRowDimension(); 383 if (rows != matrix.getColumnDimension()) { 384 if (raiseException) { 385 throw new NonSquareMatrixException(rows, matrix.getColumnDimension()); 386 } else { 387 return false; 388 } 389 } 390 for (int i = 0; i < rows; i++) { 391 for (int j = i + 1; j < rows; j++) { 392 final double mij = matrix.getEntry(i, j); 393 final double mji = matrix.getEntry(j, i); 394 if (FastMath.abs(mij - mji) 395 > FastMath.max(FastMath.abs(mij), FastMath.abs(mji)) * relativeTolerance) { 396 if (raiseException) { 397 throw new NonSymmetricMatrixException(i, j, relativeTolerance); 398 } else { 399 return false; 400 } 401 } 402 } 403 } 404 return true; 405 } 406 407 /** 408 * Checks whether a matrix is symmetric. 409 * 410 * @param matrix Matrix to check. 411 * @param eps Relative tolerance. 412 * @throws NonSquareMatrixException if the matrix is not square. 413 * @throws NonSymmetricMatrixException if the matrix is not symmetric. 414 * @since 3.1 415 */ checkSymmetric(RealMatrix matrix, double eps)416 public static void checkSymmetric(RealMatrix matrix, double eps) { 417 isSymmetricInternal(matrix, eps, true); 418 } 419 420 /** 421 * Checks whether a matrix is symmetric. 422 * 423 * @param matrix Matrix to check. 424 * @param eps Relative tolerance. 425 * @return {@code true} if {@code matrix} is symmetric. 426 * @since 3.1 427 */ isSymmetric(RealMatrix matrix, double eps)428 public static boolean isSymmetric(RealMatrix matrix, double eps) { 429 return isSymmetricInternal(matrix, eps, false); 430 } 431 432 /** 433 * Check if matrix indices are valid. 434 * 435 * @param m Matrix. 436 * @param row Row index to check. 437 * @param column Column index to check. 438 * @throws OutOfRangeException if {@code row} or {@code column} is not a valid index. 439 */ checkMatrixIndex(final AnyMatrix m, final int row, final int column)440 public static void checkMatrixIndex(final AnyMatrix m, final int row, final int column) 441 throws OutOfRangeException { 442 checkRowIndex(m, row); 443 checkColumnIndex(m, column); 444 } 445 446 /** 447 * Check if a row index is valid. 448 * 449 * @param m Matrix. 450 * @param row Row index to check. 451 * @throws OutOfRangeException if {@code row} is not a valid index. 452 */ checkRowIndex(final AnyMatrix m, final int row)453 public static void checkRowIndex(final AnyMatrix m, final int row) throws OutOfRangeException { 454 if (row < 0 || row >= m.getRowDimension()) { 455 throw new OutOfRangeException( 456 LocalizedFormats.ROW_INDEX, row, 0, m.getRowDimension() - 1); 457 } 458 } 459 460 /** 461 * Check if a column index is valid. 462 * 463 * @param m Matrix. 464 * @param column Column index to check. 465 * @throws OutOfRangeException if {@code column} is not a valid index. 466 */ checkColumnIndex(final AnyMatrix m, final int column)467 public static void checkColumnIndex(final AnyMatrix m, final int column) 468 throws OutOfRangeException { 469 if (column < 0 || column >= m.getColumnDimension()) { 470 throw new OutOfRangeException( 471 LocalizedFormats.COLUMN_INDEX, column, 0, m.getColumnDimension() - 1); 472 } 473 } 474 475 /** 476 * Check if submatrix ranges indices are valid. Rows and columns are indicated counting from 0 477 * to {@code n - 1}. 478 * 479 * @param m Matrix. 480 * @param startRow Initial row index. 481 * @param endRow Final row index. 482 * @param startColumn Initial column index. 483 * @param endColumn Final column index. 484 * @throws OutOfRangeException if the indices are invalid. 485 * @throws NumberIsTooSmallException if {@code endRow < startRow} or {@code endColumn < 486 * startColumn}. 487 */ checkSubMatrixIndex( final AnyMatrix m, final int startRow, final int endRow, final int startColumn, final int endColumn)488 public static void checkSubMatrixIndex( 489 final AnyMatrix m, 490 final int startRow, 491 final int endRow, 492 final int startColumn, 493 final int endColumn) 494 throws NumberIsTooSmallException, OutOfRangeException { 495 checkRowIndex(m, startRow); 496 checkRowIndex(m, endRow); 497 if (endRow < startRow) { 498 throw new NumberIsTooSmallException( 499 LocalizedFormats.INITIAL_ROW_AFTER_FINAL_ROW, endRow, startRow, false); 500 } 501 502 checkColumnIndex(m, startColumn); 503 checkColumnIndex(m, endColumn); 504 if (endColumn < startColumn) { 505 throw new NumberIsTooSmallException( 506 LocalizedFormats.INITIAL_COLUMN_AFTER_FINAL_COLUMN, 507 endColumn, 508 startColumn, 509 false); 510 } 511 } 512 513 /** 514 * Check if submatrix ranges indices are valid. Rows and columns are indicated counting from 0 515 * to n-1. 516 * 517 * @param m Matrix. 518 * @param selectedRows Array of row indices. 519 * @param selectedColumns Array of column indices. 520 * @throws NullArgumentException if {@code selectedRows} or {@code selectedColumns} are {@code 521 * null}. 522 * @throws NoDataException if the row or column selections are empty (zero length). 523 * @throws OutOfRangeException if row or column selections are not valid. 524 */ checkSubMatrixIndex( final AnyMatrix m, final int[] selectedRows, final int[] selectedColumns)525 public static void checkSubMatrixIndex( 526 final AnyMatrix m, final int[] selectedRows, final int[] selectedColumns) 527 throws NoDataException, NullArgumentException, OutOfRangeException { 528 if (selectedRows == null) { 529 throw new NullArgumentException(); 530 } 531 if (selectedColumns == null) { 532 throw new NullArgumentException(); 533 } 534 if (selectedRows.length == 0) { 535 throw new NoDataException(LocalizedFormats.EMPTY_SELECTED_ROW_INDEX_ARRAY); 536 } 537 if (selectedColumns.length == 0) { 538 throw new NoDataException(LocalizedFormats.EMPTY_SELECTED_COLUMN_INDEX_ARRAY); 539 } 540 541 for (final int row : selectedRows) { 542 checkRowIndex(m, row); 543 } 544 for (final int column : selectedColumns) { 545 checkColumnIndex(m, column); 546 } 547 } 548 549 /** 550 * Check if matrices are addition compatible. 551 * 552 * @param left Left hand side matrix. 553 * @param right Right hand side matrix. 554 * @throws MatrixDimensionMismatchException if the matrices are not addition compatible. 555 */ checkAdditionCompatible(final AnyMatrix left, final AnyMatrix right)556 public static void checkAdditionCompatible(final AnyMatrix left, final AnyMatrix right) 557 throws MatrixDimensionMismatchException { 558 if ((left.getRowDimension() != right.getRowDimension()) 559 || (left.getColumnDimension() != right.getColumnDimension())) { 560 throw new MatrixDimensionMismatchException( 561 left.getRowDimension(), left.getColumnDimension(), 562 right.getRowDimension(), right.getColumnDimension()); 563 } 564 } 565 566 /** 567 * Check if matrices are subtraction compatible 568 * 569 * @param left Left hand side matrix. 570 * @param right Right hand side matrix. 571 * @throws MatrixDimensionMismatchException if the matrices are not addition compatible. 572 */ checkSubtractionCompatible(final AnyMatrix left, final AnyMatrix right)573 public static void checkSubtractionCompatible(final AnyMatrix left, final AnyMatrix right) 574 throws MatrixDimensionMismatchException { 575 if ((left.getRowDimension() != right.getRowDimension()) 576 || (left.getColumnDimension() != right.getColumnDimension())) { 577 throw new MatrixDimensionMismatchException( 578 left.getRowDimension(), left.getColumnDimension(), 579 right.getRowDimension(), right.getColumnDimension()); 580 } 581 } 582 583 /** 584 * Check if matrices are multiplication compatible 585 * 586 * @param left Left hand side matrix. 587 * @param right Right hand side matrix. 588 * @throws DimensionMismatchException if matrices are not multiplication compatible. 589 */ checkMultiplicationCompatible(final AnyMatrix left, final AnyMatrix right)590 public static void checkMultiplicationCompatible(final AnyMatrix left, final AnyMatrix right) 591 throws DimensionMismatchException { 592 593 if (left.getColumnDimension() != right.getRowDimension()) { 594 throw new DimensionMismatchException( 595 left.getColumnDimension(), right.getRowDimension()); 596 } 597 } 598 599 /** 600 * Convert a {@link FieldMatrix}/{@link Fraction} matrix to a {@link RealMatrix}. 601 * 602 * @param m Matrix to convert. 603 * @return the converted matrix. 604 */ fractionMatrixToRealMatrix(final FieldMatrix<Fraction> m)605 public static Array2DRowRealMatrix fractionMatrixToRealMatrix(final FieldMatrix<Fraction> m) { 606 final FractionMatrixConverter converter = new FractionMatrixConverter(); 607 m.walkInOptimizedOrder(converter); 608 return converter.getConvertedMatrix(); 609 } 610 611 /** Converter for {@link FieldMatrix}/{@link Fraction}. */ 612 private static class FractionMatrixConverter 613 extends DefaultFieldMatrixPreservingVisitor<Fraction> { 614 /** Converted array. */ 615 private double[][] data; 616 617 /** Simple constructor. */ FractionMatrixConverter()618 FractionMatrixConverter() { 619 super(Fraction.ZERO); 620 } 621 622 /** {@inheritDoc} */ 623 @Override start( int rows, int columns, int startRow, int endRow, int startColumn, int endColumn)624 public void start( 625 int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { 626 data = new double[rows][columns]; 627 } 628 629 /** {@inheritDoc} */ 630 @Override visit(int row, int column, Fraction value)631 public void visit(int row, int column, Fraction value) { 632 data[row][column] = value.doubleValue(); 633 } 634 635 /** 636 * Get the converted matrix. 637 * 638 * @return the converted matrix. 639 */ getConvertedMatrix()640 Array2DRowRealMatrix getConvertedMatrix() { 641 return new Array2DRowRealMatrix(data, false); 642 } 643 } 644 645 /** 646 * Convert a {@link FieldMatrix}/{@link BigFraction} matrix to a {@link RealMatrix}. 647 * 648 * @param m Matrix to convert. 649 * @return the converted matrix. 650 */ bigFractionMatrixToRealMatrix( final FieldMatrix<BigFraction> m)651 public static Array2DRowRealMatrix bigFractionMatrixToRealMatrix( 652 final FieldMatrix<BigFraction> m) { 653 final BigFractionMatrixConverter converter = new BigFractionMatrixConverter(); 654 m.walkInOptimizedOrder(converter); 655 return converter.getConvertedMatrix(); 656 } 657 658 /** Converter for {@link FieldMatrix}/{@link BigFraction}. */ 659 private static class BigFractionMatrixConverter 660 extends DefaultFieldMatrixPreservingVisitor<BigFraction> { 661 /** Converted array. */ 662 private double[][] data; 663 664 /** Simple constructor. */ BigFractionMatrixConverter()665 BigFractionMatrixConverter() { 666 super(BigFraction.ZERO); 667 } 668 669 /** {@inheritDoc} */ 670 @Override start( int rows, int columns, int startRow, int endRow, int startColumn, int endColumn)671 public void start( 672 int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { 673 data = new double[rows][columns]; 674 } 675 676 /** {@inheritDoc} */ 677 @Override visit(int row, int column, BigFraction value)678 public void visit(int row, int column, BigFraction value) { 679 data[row][column] = value.doubleValue(); 680 } 681 682 /** 683 * Get the converted matrix. 684 * 685 * @return the converted matrix. 686 */ getConvertedMatrix()687 Array2DRowRealMatrix getConvertedMatrix() { 688 return new Array2DRowRealMatrix(data, false); 689 } 690 } 691 692 /** 693 * Serialize a {@link RealVector}. 694 * 695 * <p>This method is intended to be called from within a private <code>writeObject</code> method 696 * (after a call to <code>oos.defaultWriteObject()</code>) in a class that has a {@link 697 * RealVector} field, which should be declared <code>transient</code>. This way, the default 698 * handling does not serialize the vector (the {@link RealVector} interface is not serializable 699 * by default) but this method does serialize it specifically. 700 * 701 * <p>The following example shows how a simple class with a name and a real vector should be 702 * written: 703 * 704 * <pre><code> 705 * public class NamedVector implements Serializable { 706 * 707 * private final String name; 708 * private final transient RealVector coefficients; 709 * 710 * // omitted constructors, getters ... 711 * 712 * private void writeObject(ObjectOutputStream oos) throws IOException { 713 * oos.defaultWriteObject(); // takes care of name field 714 * MatrixUtils.serializeRealVector(coefficients, oos); 715 * } 716 * 717 * private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 718 * ois.defaultReadObject(); // takes care of name field 719 * MatrixUtils.deserializeRealVector(this, "coefficients", ois); 720 * } 721 * 722 * } 723 * </code></pre> 724 * 725 * @param vector real vector to serialize 726 * @param oos stream where the real vector should be written 727 * @exception IOException if object cannot be written to stream 728 * @see #deserializeRealVector(Object, String, ObjectInputStream) 729 */ serializeRealVector(final RealVector vector, final ObjectOutputStream oos)730 public static void serializeRealVector(final RealVector vector, final ObjectOutputStream oos) 731 throws IOException { 732 final int n = vector.getDimension(); 733 oos.writeInt(n); 734 for (int i = 0; i < n; ++i) { 735 oos.writeDouble(vector.getEntry(i)); 736 } 737 } 738 739 /** 740 * Deserialize a {@link RealVector} field in a class. 741 * 742 * <p>This method is intended to be called from within a private <code>readObject</code> method 743 * (after a call to <code>ois.defaultReadObject()</code>) in a class that has a {@link 744 * RealVector} field, which should be declared <code>transient</code>. This way, the default 745 * handling does not deserialize the vector (the {@link RealVector} interface is not 746 * serializable by default) but this method does deserialize it specifically. 747 * 748 * @param instance instance in which the field must be set up 749 * @param fieldName name of the field within the class (may be private and final) 750 * @param ois stream from which the real vector should be read 751 * @exception ClassNotFoundException if a class in the stream cannot be found 752 * @exception IOException if object cannot be read from the stream 753 * @see #serializeRealVector(RealVector, ObjectOutputStream) 754 */ deserializeRealVector( final Object instance, final String fieldName, final ObjectInputStream ois)755 public static void deserializeRealVector( 756 final Object instance, final String fieldName, final ObjectInputStream ois) 757 throws ClassNotFoundException, IOException { 758 try { 759 760 // read the vector data 761 final int n = ois.readInt(); 762 final double[] data = new double[n]; 763 for (int i = 0; i < n; ++i) { 764 data[i] = ois.readDouble(); 765 } 766 767 // create the instance 768 final RealVector vector = new ArrayRealVector(data, false); 769 770 // set up the field 771 final java.lang.reflect.Field f = instance.getClass().getDeclaredField(fieldName); 772 f.setAccessible(true); 773 f.set(instance, vector); 774 775 } catch (NoSuchFieldException nsfe) { 776 IOException ioe = new IOException(); 777 ioe.initCause(nsfe); 778 throw ioe; 779 } catch (IllegalAccessException iae) { 780 IOException ioe = new IOException(); 781 ioe.initCause(iae); 782 throw ioe; 783 } 784 } 785 786 /** 787 * Serialize a {@link RealMatrix}. 788 * 789 * <p>This method is intended to be called from within a private <code>writeObject</code> method 790 * (after a call to <code>oos.defaultWriteObject()</code>) in a class that has a {@link 791 * RealMatrix} field, which should be declared <code>transient</code>. This way, the default 792 * handling does not serialize the matrix (the {@link RealMatrix} interface is not serializable 793 * by default) but this method does serialize it specifically. 794 * 795 * <p>The following example shows how a simple class with a name and a real matrix should be 796 * written: 797 * 798 * <pre><code> 799 * public class NamedMatrix implements Serializable { 800 * 801 * private final String name; 802 * private final transient RealMatrix coefficients; 803 * 804 * // omitted constructors, getters ... 805 * 806 * private void writeObject(ObjectOutputStream oos) throws IOException { 807 * oos.defaultWriteObject(); // takes care of name field 808 * MatrixUtils.serializeRealMatrix(coefficients, oos); 809 * } 810 * 811 * private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 812 * ois.defaultReadObject(); // takes care of name field 813 * MatrixUtils.deserializeRealMatrix(this, "coefficients", ois); 814 * } 815 * 816 * } 817 * </code></pre> 818 * 819 * @param matrix real matrix to serialize 820 * @param oos stream where the real matrix should be written 821 * @exception IOException if object cannot be written to stream 822 * @see #deserializeRealMatrix(Object, String, ObjectInputStream) 823 */ serializeRealMatrix(final RealMatrix matrix, final ObjectOutputStream oos)824 public static void serializeRealMatrix(final RealMatrix matrix, final ObjectOutputStream oos) 825 throws IOException { 826 final int n = matrix.getRowDimension(); 827 final int m = matrix.getColumnDimension(); 828 oos.writeInt(n); 829 oos.writeInt(m); 830 for (int i = 0; i < n; ++i) { 831 for (int j = 0; j < m; ++j) { 832 oos.writeDouble(matrix.getEntry(i, j)); 833 } 834 } 835 } 836 837 /** 838 * Deserialize a {@link RealMatrix} field in a class. 839 * 840 * <p>This method is intended to be called from within a private <code>readObject</code> method 841 * (after a call to <code>ois.defaultReadObject()</code>) in a class that has a {@link 842 * RealMatrix} field, which should be declared <code>transient</code>. This way, the default 843 * handling does not deserialize the matrix (the {@link RealMatrix} interface is not 844 * serializable by default) but this method does deserialize it specifically. 845 * 846 * @param instance instance in which the field must be set up 847 * @param fieldName name of the field within the class (may be private and final) 848 * @param ois stream from which the real matrix should be read 849 * @exception ClassNotFoundException if a class in the stream cannot be found 850 * @exception IOException if object cannot be read from the stream 851 * @see #serializeRealMatrix(RealMatrix, ObjectOutputStream) 852 */ deserializeRealMatrix( final Object instance, final String fieldName, final ObjectInputStream ois)853 public static void deserializeRealMatrix( 854 final Object instance, final String fieldName, final ObjectInputStream ois) 855 throws ClassNotFoundException, IOException { 856 try { 857 858 // read the matrix data 859 final int n = ois.readInt(); 860 final int m = ois.readInt(); 861 final double[][] data = new double[n][m]; 862 for (int i = 0; i < n; ++i) { 863 final double[] dataI = data[i]; 864 for (int j = 0; j < m; ++j) { 865 dataI[j] = ois.readDouble(); 866 } 867 } 868 869 // create the instance 870 final RealMatrix matrix = new Array2DRowRealMatrix(data, false); 871 872 // set up the field 873 final java.lang.reflect.Field f = instance.getClass().getDeclaredField(fieldName); 874 f.setAccessible(true); 875 f.set(instance, matrix); 876 877 } catch (NoSuchFieldException nsfe) { 878 IOException ioe = new IOException(); 879 ioe.initCause(nsfe); 880 throw ioe; 881 } catch (IllegalAccessException iae) { 882 IOException ioe = new IOException(); 883 ioe.initCause(iae); 884 throw ioe; 885 } 886 } 887 888 /** 889 * Solve a system of composed of a Lower Triangular Matrix {@link RealMatrix}. 890 * 891 * <p>This method is called to solve systems of equations which are of the lower triangular 892 * form. The matrix {@link RealMatrix} is assumed, though not checked, to be in lower triangular 893 * form. The vector {@link RealVector} is overwritten with the solution. The matrix is checked 894 * that it is square and its dimensions match the length of the vector. 895 * 896 * @param rm RealMatrix which is lower triangular 897 * @param b RealVector this is overwritten 898 * @throws DimensionMismatchException if the matrix and vector are not conformable 899 * @throws NonSquareMatrixException if the matrix {@code rm} is not square 900 * @throws MathArithmeticException if the absolute value of one of the diagonal coefficient of 901 * {@code rm} is lower than {@link Precision#SAFE_MIN} 902 */ solveLowerTriangularSystem(RealMatrix rm, RealVector b)903 public static void solveLowerTriangularSystem(RealMatrix rm, RealVector b) 904 throws DimensionMismatchException, MathArithmeticException, NonSquareMatrixException { 905 if ((rm == null) || (b == null) || (rm.getRowDimension() != b.getDimension())) { 906 throw new DimensionMismatchException( 907 (rm == null) ? 0 : rm.getRowDimension(), (b == null) ? 0 : b.getDimension()); 908 } 909 if (rm.getColumnDimension() != rm.getRowDimension()) { 910 throw new NonSquareMatrixException(rm.getRowDimension(), rm.getColumnDimension()); 911 } 912 int rows = rm.getRowDimension(); 913 for (int i = 0; i < rows; i++) { 914 double diag = rm.getEntry(i, i); 915 if (FastMath.abs(diag) < Precision.SAFE_MIN) { 916 throw new MathArithmeticException(LocalizedFormats.ZERO_DENOMINATOR); 917 } 918 double bi = b.getEntry(i) / diag; 919 b.setEntry(i, bi); 920 for (int j = i + 1; j < rows; j++) { 921 b.setEntry(j, b.getEntry(j) - bi * rm.getEntry(j, i)); 922 } 923 } 924 } 925 926 /** 927 * Solver a system composed of an Upper Triangular Matrix {@link RealMatrix}. 928 * 929 * <p>This method is called to solve systems of equations which are of the lower triangular 930 * form. The matrix {@link RealMatrix} is assumed, though not checked, to be in upper triangular 931 * form. The vector {@link RealVector} is overwritten with the solution. The matrix is checked 932 * that it is square and its dimensions match the length of the vector. 933 * 934 * @param rm RealMatrix which is upper triangular 935 * @param b RealVector this is overwritten 936 * @throws DimensionMismatchException if the matrix and vector are not conformable 937 * @throws NonSquareMatrixException if the matrix {@code rm} is not square 938 * @throws MathArithmeticException if the absolute value of one of the diagonal coefficient of 939 * {@code rm} is lower than {@link Precision#SAFE_MIN} 940 */ solveUpperTriangularSystem(RealMatrix rm, RealVector b)941 public static void solveUpperTriangularSystem(RealMatrix rm, RealVector b) 942 throws DimensionMismatchException, MathArithmeticException, NonSquareMatrixException { 943 if ((rm == null) || (b == null) || (rm.getRowDimension() != b.getDimension())) { 944 throw new DimensionMismatchException( 945 (rm == null) ? 0 : rm.getRowDimension(), (b == null) ? 0 : b.getDimension()); 946 } 947 if (rm.getColumnDimension() != rm.getRowDimension()) { 948 throw new NonSquareMatrixException(rm.getRowDimension(), rm.getColumnDimension()); 949 } 950 int rows = rm.getRowDimension(); 951 for (int i = rows - 1; i > -1; i--) { 952 double diag = rm.getEntry(i, i); 953 if (FastMath.abs(diag) < Precision.SAFE_MIN) { 954 throw new MathArithmeticException(LocalizedFormats.ZERO_DENOMINATOR); 955 } 956 double bi = b.getEntry(i) / diag; 957 b.setEntry(i, bi); 958 for (int j = i - 1; j > -1; j--) { 959 b.setEntry(j, b.getEntry(j) - bi * rm.getEntry(j, i)); 960 } 961 } 962 } 963 964 /** 965 * Computes the inverse of the given matrix by splitting it into 4 sub-matrices. 966 * 967 * @param m Matrix whose inverse must be computed. 968 * @param splitIndex Index that determines the "split" line and column. The element 969 * corresponding to this index will part of the upper-left sub-matrix. 970 * @return the inverse of {@code m}. 971 * @throws NonSquareMatrixException if {@code m} is not square. 972 */ blockInverse(RealMatrix m, int splitIndex)973 public static RealMatrix blockInverse(RealMatrix m, int splitIndex) { 974 final int n = m.getRowDimension(); 975 if (m.getColumnDimension() != n) { 976 throw new NonSquareMatrixException(m.getRowDimension(), m.getColumnDimension()); 977 } 978 979 final int splitIndex1 = splitIndex + 1; 980 981 final RealMatrix a = m.getSubMatrix(0, splitIndex, 0, splitIndex); 982 final RealMatrix b = m.getSubMatrix(0, splitIndex, splitIndex1, n - 1); 983 final RealMatrix c = m.getSubMatrix(splitIndex1, n - 1, 0, splitIndex); 984 final RealMatrix d = m.getSubMatrix(splitIndex1, n - 1, splitIndex1, n - 1); 985 986 final SingularValueDecomposition aDec = new SingularValueDecomposition(a); 987 final DecompositionSolver aSolver = aDec.getSolver(); 988 if (!aSolver.isNonSingular()) { 989 throw new SingularMatrixException(); 990 } 991 final RealMatrix aInv = aSolver.getInverse(); 992 993 final SingularValueDecomposition dDec = new SingularValueDecomposition(d); 994 final DecompositionSolver dSolver = dDec.getSolver(); 995 if (!dSolver.isNonSingular()) { 996 throw new SingularMatrixException(); 997 } 998 final RealMatrix dInv = dSolver.getInverse(); 999 1000 final RealMatrix tmp1 = a.subtract(b.multiply(dInv).multiply(c)); 1001 final SingularValueDecomposition tmp1Dec = new SingularValueDecomposition(tmp1); 1002 final DecompositionSolver tmp1Solver = tmp1Dec.getSolver(); 1003 if (!tmp1Solver.isNonSingular()) { 1004 throw new SingularMatrixException(); 1005 } 1006 final RealMatrix result00 = tmp1Solver.getInverse(); 1007 1008 final RealMatrix tmp2 = d.subtract(c.multiply(aInv).multiply(b)); 1009 final SingularValueDecomposition tmp2Dec = new SingularValueDecomposition(tmp2); 1010 final DecompositionSolver tmp2Solver = tmp2Dec.getSolver(); 1011 if (!tmp2Solver.isNonSingular()) { 1012 throw new SingularMatrixException(); 1013 } 1014 final RealMatrix result11 = tmp2Solver.getInverse(); 1015 1016 final RealMatrix result01 = aInv.multiply(b).multiply(result11).scalarMultiply(-1); 1017 final RealMatrix result10 = dInv.multiply(c).multiply(result00).scalarMultiply(-1); 1018 1019 final RealMatrix result = new Array2DRowRealMatrix(n, n); 1020 result.setSubMatrix(result00.getData(), 0, 0); 1021 result.setSubMatrix(result01.getData(), 0, splitIndex1); 1022 result.setSubMatrix(result10.getData(), splitIndex1, 0); 1023 result.setSubMatrix(result11.getData(), splitIndex1, splitIndex1); 1024 1025 return result; 1026 } 1027 1028 /** 1029 * Computes the inverse of the given matrix. 1030 * 1031 * <p>By default, the inverse of the matrix is computed using the QR-decomposition, unless a 1032 * more efficient method can be determined for the input matrix. 1033 * 1034 * <p>Note: this method will use a singularity threshold of 0, use {@link #inverse(RealMatrix, 1035 * double)} if a different threshold is needed. 1036 * 1037 * @param matrix Matrix whose inverse shall be computed 1038 * @return the inverse of {@code matrix} 1039 * @throws NullArgumentException if {@code matrix} is {@code null} 1040 * @throws SingularMatrixException if m is singular 1041 * @throws NonSquareMatrixException if matrix is not square 1042 * @since 3.3 1043 */ inverse(RealMatrix matrix)1044 public static RealMatrix inverse(RealMatrix matrix) 1045 throws NullArgumentException, SingularMatrixException, NonSquareMatrixException { 1046 return inverse(matrix, 0); 1047 } 1048 1049 /** 1050 * Computes the inverse of the given matrix. 1051 * 1052 * <p>By default, the inverse of the matrix is computed using the QR-decomposition, unless a 1053 * more efficient method can be determined for the input matrix. 1054 * 1055 * @param matrix Matrix whose inverse shall be computed 1056 * @param threshold Singularity threshold 1057 * @return the inverse of {@code m} 1058 * @throws NullArgumentException if {@code matrix} is {@code null} 1059 * @throws SingularMatrixException if matrix is singular 1060 * @throws NonSquareMatrixException if matrix is not square 1061 * @since 3.3 1062 */ inverse(RealMatrix matrix, double threshold)1063 public static RealMatrix inverse(RealMatrix matrix, double threshold) 1064 throws NullArgumentException, SingularMatrixException, NonSquareMatrixException { 1065 1066 MathUtils.checkNotNull(matrix); 1067 1068 if (!matrix.isSquare()) { 1069 throw new NonSquareMatrixException( 1070 matrix.getRowDimension(), matrix.getColumnDimension()); 1071 } 1072 1073 if (matrix instanceof DiagonalMatrix) { 1074 return ((DiagonalMatrix) matrix).inverse(threshold); 1075 } else { 1076 QRDecomposition decomposition = new QRDecomposition(matrix, threshold); 1077 return decomposition.getSolver().getInverse(); 1078 } 1079 } 1080 } 1081