1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_H_ 17 #define TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_H_ 18 19 #include <string> 20 21 #include "absl/types/optional.h" 22 #include "mlir/IR/Builders.h" // from @llvm-project 23 #include "mlir/IR/Operation.h" // from @llvm-project 24 #include "mlir/IR/UseDefLists.h" // from @llvm-project 25 #include "tensorflow/core/framework/registration/registration.h" 26 #include "tensorflow/dtensor/cc/dstatus.h" 27 #include "tensorflow/dtensor/cc/tensor_layout.h" 28 #include "tensorflow/dtensor/mlir/spmd_expander_common.h" 29 30 namespace tensorflow { 31 namespace dtensor { 32 33 // Base class for handling SPMD expansion of a MLIR TF Operation. 34 class SPMDExpanderBase { 35 public: ~SPMDExpanderBase()36 virtual ~SPMDExpanderBase() {} 37 38 // Converts `op` to a SPMD expanded form. SPMD expansion logic is 39 // a function of op type, op output's layout, and layout of op's 40 // inputs. Must return the `op` that is expanded as the final return value. 41 virtual StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) = 0; 42 43 // Layout propagation functions. 44 // 45 // During the layout algorithm, for each op output we compute a layout by 46 // merging the current layout request from the op producing the output and the 47 // layout requests from the ops consuming the output. These merged layouts 48 // represent the current state of layouts over the entire mlir module. 49 // 50 // For an op, if any of the merged layouts for the inputs or output are 51 // updated, the ComputeLayoutForward and ComputeLayoutBackward functions will 52 // be called with all the updated layout maps populated. 53 // 54 // ComputeLayoutForward should take the input layouts and determine which 55 // output layout these inputs would produce. Likewise, ComputeLayoutBackward 56 // should take the output layouts and determine the what layouts to propagate 57 // to the inputs. 58 // 59 // In both cases the functions should choose layouts that reduce the amount of 60 // cross device communication for the op. 61 // 62 // ComputeLayoutForward should not take into account the current output 63 // layout(s) when computing the new ones. The merge algorithm will decide what 64 // to do. There are only a very few cases where the current output layout may 65 // need to propagated again, in which case those ops can override the 66 // expanded ComputeLayout* functions. This similarly applies to 67 // ComputeLayoutBackward. 68 // 69 // Note that for some ops, where the input layout does not determine output 70 // layout (and visa versa), it is acceptable to either return a replicated 71 // layout. E.g. for tf.Fill, ComputeLayoutForward can return a replicated 72 // output layout and if a consumer requests a more sharded layout, then the 73 // layout algorithm will merge the requests, resulting in the more sharded 74 // layout. 75 76 // Computes output layout(s) of `op` based on the current `input_layouts` 77 // inferred from inputs of `op`. The `input_layouts` parameter maps input 78 // indices to the corresponding layouts. It may be empty if the op has no 79 // operands or if no input layouts have been inferred yet. 80 virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( 81 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts); 82 83 // Computes output layout(s) of `op` based on the current `input_layouts` and 84 // `output_layouts` inferred from the inputs and outputs of `op`. Both 85 // parameters maps input/output indices to the corresponding layouts. Either 86 // may be empty. 87 // 88 // NOTE: The other ComputeLayoutForward function should be preferred since in 89 // most cases the output layouts are only computed based on the input layouts. 90 virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( 91 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts, 92 const llvm::DenseMap<int, Layout>& output_layouts); 93 94 // Computes input layout(s) of `op` based on the current `output_layouts` 95 // inferred from outputs of `op`. The `output_layouts` parameter maps output 96 // indices to the corresponding layouts. It may be empty if the op has no 97 // outputs or if no output layouts have been inferred yet. 98 virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( 99 mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts); 100 101 // Computes input layout(s) of `op` based on the current `output_layouts` and 102 // `input_layouts` inferred from the outputs and inputs of `op`. Both 103 // parameters maps input/output indices to the corresponding layouts. Either 104 // may be empty. 105 // 106 // NOTE: The other ComputeLayoutBackward function should be preferred since in 107 // most cases the input layouts are only computed based on the output layouts. 108 virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( 109 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts, 110 const llvm::DenseMap<int, Layout>& output_layouts); 111 112 // Run ExpandOp() and set layout from the computed layout from original op. 113 // Returns the expanded op in output. 114 Status ExpandOpAndSetLayout(mlir::Operation* op, mlir::Operation** output); 115 }; 116 117 // Computes the SPMD expansion for `op`. 118 // 119 // Prior to this call, all inputs to `op` have been lowered to local operations 120 // & shapes. The lowered op must emit a type compatible with the local shape. 121 Status RunSPMDExpansion(mlir::Operation* op, mlir::Operation** output); 122 123 // A registry of SPMD expanders. This map is statically stored and initialized 124 // with all the registered SPMD expanders. 125 class SPMDExpanderRegistry { 126 public: 127 ~SPMDExpanderRegistry() = default; 128 129 // A singleton available at startup. 130 static SPMDExpanderRegistry* Global(); 131 132 // Returns the expansion for the given operation (or nullptr if no expansion 133 // has been registered). 134 SPMDExpanderBase* GetPropagateFnForOp(mlir::Operation* op); 135 136 // Registers an expander for the provided opName. 137 InitOnStartupMarker RegisterPropagateFn( 138 std::string opName, std::unique_ptr<SPMDExpanderBase> prop); 139 140 private: 141 absl::flat_hash_map<std::string, std::unique_ptr<SPMDExpanderBase>> 142 op_to_propagate_fn_map_; 143 }; 144 145 #define REGISTER_SPMD(name, op, prop, ...) \ 146 static ::tensorflow::InitOnStartupMarker const spmd_##name = \ 147 InitOnStartupMarker{} \ 148 << SPMDExpanderRegistry::Global()->RegisterPropagateFn( \ 149 mlir::op ::getOperationName().str(), \ 150 std::make_unique<prop>(__VA_ARGS__)) 151 152 } // namespace dtensor 153 } // namespace tensorflow 154 155 #endif // TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_H_ 156