1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ 18 19 #include "absl/container/flat_hash_map.h" 20 #include "tensorflow/compiler/xla/client/xla_builder.h" 21 #include "tensorflow/compiler/xla/service/op_expander_pass.h" 22 23 namespace xla { 24 25 class TriangularSolveExpander : public OpExpanderPass { 26 public: 27 explicit TriangularSolveExpander(int64_t block_size = 128); 28 name()29 absl::string_view name() const override { 30 return "triangular_solve_expander"; 31 } 32 33 protected: 34 // Should we use direct solves for batched inputs? UseDirectSolves()35 virtual bool UseDirectSolves() const { return true; } 36 37 bool InstructionMatchesPattern(HloInstruction* instruction) override; 38 39 StatusOr<HloInstruction*> ExpandInstruction( 40 HloInstruction* instruction) override; 41 42 // Performs a triangular solve using an algorithm from MAGMA, which inverts 43 // diagonal blocks and multiplies them using matrix multiplications. 44 XlaOp SolveByInvertingDiagonalBlocks(XlaOp a, XlaOp b, bool left_side, 45 bool lower, bool transpose_a, 46 bool conjugate_a, bool unit_diagonal, 47 PrecisionConfig::Precision precision); 48 49 // Helper function used by SolveByInvertingDiagonalBlocks 50 virtual XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower_triangular, 51 PrecisionConfig::Precision precision); 52 53 // Performs a direct triangular solve, suitable for case with small matrices 54 // or with large batch. 55 XlaOp SolveDirectly(XlaOp a, XlaOp b, bool left_side, bool lower, 56 bool transpose_a, bool conjugate_a, bool unit_diagonal, 57 PrecisionConfig::Precision precision); 58 59 XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, 60 bool transpose_a, bool conjugate_a, 61 bool unit_diagonal, int64_t block_size, 62 PrecisionConfig::Precision precision); 63 64 private: 65 // Block size for BuildTriangularSolve 66 const int64_t block_size_; 67 // Mapping from op signatures to existing computations. 68 absl::flat_hash_map<std::string, HloComputation*> computation_cache_; 69 }; 70 71 } // namespace xla 72 73 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ 74