/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_H_
#define TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_H_

#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {

// Base class for handling SPMD expansion of an MLIR TF Operation.
class SPMDExpanderBase {
 public:
  virtual ~SPMDExpanderBase() {}

  // Converts `op` to its SPMD-expanded form. The SPMD expansion logic is a
  // function of the op type, the layout of the op's outputs, and the layouts
  // of the op's inputs. Must return the expanded `op` as the final return
  // value.
  virtual StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) = 0;

  // Layout propagation functions.
  //
  // During the layout algorithm, for each op output we compute a layout by
  // merging the current layout request from the op producing the output and
  // the layout requests from the ops consuming the output. These merged
  // layouts represent the current state of layouts over the entire mlir
  // module.
  //
  // For an op, if any of the merged layouts for the inputs or outputs are
  // updated, the ComputeLayoutForward and ComputeLayoutBackward functions will
  // be called with all the updated layout maps populated.
  //
  // ComputeLayoutForward should take the input layouts and determine which
  // output layouts these inputs would produce. Likewise, ComputeLayoutBackward
  // should take the output layouts and determine what layouts to propagate to
  // the inputs.
  //
  // In both cases the functions should choose layouts that reduce the amount
  // of cross-device communication for the op.
  //
  // ComputeLayoutForward should not take into account the current output
  // layout(s) when computing the new ones. The merge algorithm will decide
  // what to do. There are only a very few cases where the current output
  // layout may need to be propagated again, in which case those ops can
  // override the expanded ComputeLayout* functions. This similarly applies to
  // ComputeLayoutBackward.
  //
  // Note that for ops where the input layout does not determine the output
  // layout (and vice versa), it is acceptable to return a replicated layout.
  // E.g. for tf.Fill, ComputeLayoutForward can return a replicated output
  // layout, and if a consumer requests a more sharded layout, the layout
  // algorithm will merge the requests, resulting in the more sharded layout.
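  //
  // As an illustrative sketch only (hypothetical, not an expander defined in
  // this header; `ValueRank` is a helper assumed from the surrounding DTensor
  // code, and `mesh` is assumed to have been extracted from the op's enclosing
  // cluster), such a tf.Fill-like op could implement:
  //
  //   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
  //       mlir::Operation* op,
  //       const llvm::DenseMap<int, Layout>& input_layouts) override {
  //     // Request a fully replicated output; if a consumer requests a more
  //     // sharded layout, the merge step will prefer the sharded one.
  //     int rank = ValueRank(op->getResult(0));
  //     return llvm::DenseMap<int, Layout>(
  //         {{0, Layout::ReplicatedOnMesh(mesh, rank)}});
  //   }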

  // Computes output layout(s) of `op` based on the current `input_layouts`
  // inferred from inputs of `op`. The `input_layouts` parameter maps input
  // indices to the corresponding layouts. It may be empty if the op has no
  // operands or if no input layouts have been inferred yet.
  virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
      mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts);

  // Computes output layout(s) of `op` based on the current `input_layouts` and
  // `output_layouts` inferred from the inputs and outputs of `op`. Both
  // parameters map input/output indices to the corresponding layouts. Either
  // may be empty.
  //
  // NOTE: The other ComputeLayoutForward function should be preferred since in
  // most cases the output layouts are computed based only on the input
  // layouts.
  virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
      mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
      const llvm::DenseMap<int, Layout>& output_layouts);

  // Computes input layout(s) of `op` based on the current `output_layouts`
  // inferred from outputs of `op`. The `output_layouts` parameter maps output
  // indices to the corresponding layouts. It may be empty if the op has no
  // outputs or if no output layouts have been inferred yet.
  virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
      mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts);

  // Computes input layout(s) of `op` based on the current `output_layouts` and
  // `input_layouts` inferred from the outputs and inputs of `op`. Both
  // parameters map input/output indices to the corresponding layouts. Either
  // may be empty.
  //
  // NOTE: The other ComputeLayoutBackward function should be preferred since
  // in most cases the input layouts are computed based only on the output
  // layouts.
  virtual StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
      mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
      const llvm::DenseMap<int, Layout>& output_layouts);

  // Runs ExpandOp() and sets the layout of the expanded op from the layout
  // computed for the original op. Returns the expanded op in `output`.
  Status ExpandOpAndSetLayout(mlir::Operation* op, mlir::Operation** output);
};
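
// A minimal sketch of a concrete expander, for orientation only. `MyOp`,
// `MyOpSPMDExpander`, and the elided bodies (`...`) are hypothetical, not
// part of this header:
//
//   class MyOpSPMDExpander : public SPMDExpanderBase {
//    public:
//     StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override {
//       // Rewrite `op` into its local, per-device form and return the
//       // expanded op.
//       ...
//     }
//     StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
//         mlir::Operation* op,
//         const llvm::DenseMap<int, Layout>& input_layouts) override {
//       ...
//     }
//     StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
//         mlir::Operation* op,
//         const llvm::DenseMap<int, Layout>& output_layouts) override {
//       ...
//     }
//   };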

// Computes the SPMD expansion for `op`.
//
// Prior to this call, all inputs to `op` have been lowered to local operations
// and shapes. The lowered op must emit a type compatible with the local shape.
Status RunSPMDExpansion(mlir::Operation* op, mlir::Operation** output);

// A registry of SPMD expanders. The registry is statically stored and
// initialized with all the registered SPMD expanders.
class SPMDExpanderRegistry {
 public:
  ~SPMDExpanderRegistry() = default;

  // A singleton available at startup.
  static SPMDExpanderRegistry* Global();

  // Returns the expander for the given operation (or nullptr if no expander
  // has been registered).
  SPMDExpanderBase* GetPropagateFnForOp(mlir::Operation* op);

  // Registers an expander for the provided `opName`.
  InitOnStartupMarker RegisterPropagateFn(
      std::string opName, std::unique_ptr<SPMDExpanderBase> prop);

 private:
  absl::flat_hash_map<std::string, std::unique_ptr<SPMDExpanderBase>>
      op_to_propagate_fn_map_;
};
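
// As a rough usage sketch (hypothetical call site; error handling largely
// elided), a pass might look up and run an expander like so:
//
//   SPMDExpanderBase* expander =
//       SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op);
//   if (expander != nullptr) {
//     mlir::Operation* expanded = nullptr;
//     TF_RETURN_IF_ERROR(expander->ExpandOpAndSetLayout(op, &expanded));
//   }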

#define REGISTER_SPMD(name, op, prop, ...)                     \
  static ::tensorflow::InitOnStartupMarker const spmd_##name = \
      InitOnStartupMarker{}                                    \
      << SPMDExpanderRegistry::Global()->RegisterPropagateFn(  \
             mlir::op ::getOperationName().str(),              \
             std::make_unique<prop>(__VA_ARGS__))
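
// Example registration (hypothetical: `MyOpSPMDExpander` as sketched above,
// with `TF::MyOp` standing in for a generated TF dialect op class):
//
//   REGISTER_SPMD(MyOp, TF::MyOp, MyOpSPMDExpander);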

}  // namespace dtensor
}  // namespace tensorflow

#endif  // TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_H_