/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_DATAPARALLEL_SPMD_EXPANDER_H_
#define TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_DATAPARALLEL_SPMD_EXPANDER_H_

#include <utility>
#include <vector>

#include "llvm/ADT/DenseMap.h"
#include "tensorflow/dtensor/mlir/spmd_expander.h"

namespace tensorflow {
namespace dtensor {

// General SPMD expander for data parallel ops.

// Data parallel ops are defined as ops whose tensors may have batch
// dimensions. Batch dimensions are assumed to start from the left, and a
// tensor may have any number of them, including zero.
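// For example (illustrative shapes): if an op treats the trailing 2 dimensions
// as non-batch dimensions, a tensor of shape [4, 8, 128, 128] has batch
// dimensions [4, 8], while a tensor of shape [128, 128] has none.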
class DataparallelSPMDExpander : public SPMDExpanderBase {
 protected:
  // These maps contain {arg_index, non_batch_rank} entries. For example, for
  // tf.FFT2D both batchable_operands_ and batchable_outputs_ contain {0, 2},
  // because the first argument is batchable and its last 2 dimensions are the
  // non-batch dimensions.
  llvm::DenseMap<int, int> batchable_operands_;
  llvm::DenseMap<int, int> batchable_outputs_;
  StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override;

  StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
      mlir::Operation* op,
      const llvm::DenseMap<int, Layout>& input_layouts) override;

  StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
      mlir::Operation* op,
      const llvm::DenseMap<int, Layout>& output_layouts) override;

 public:
  explicit DataparallelSPMDExpander(
      llvm::DenseMap<int, int> batchable_operands,
      llvm::DenseMap<int, int> batchable_outputs)
      : batchable_operands_(std::move(batchable_operands)),
        batchable_outputs_(std::move(batchable_outputs)) {}
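
  // Usage sketch (illustrative only; actual instantiation happens where the
  // expander is registered): for an op like tf.FFT2D, whose single operand
  // and single result keep their last 2 dimensions as non-batch dimensions:
  //
  //   DataparallelSPMDExpander expander(
  //       /*batchable_operands=*/llvm::DenseMap<int, int>{{0, 2}},
  //       /*batchable_outputs=*/llvm::DenseMap<int, int>{{0, 2}});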

 private:
  // Relayouts all operands and outputs that have batch dimensions to a
  // batch-sharded layout. This should only be called when there is at least
  // one batch-sharded operand or batch-sharded output.
  StatusOr<mlir::Operation*> RelayoutOperandsAndOutputs(
      mlir::Operation* op, const std::vector<Layout>& operand_layouts,
      const std::vector<Layout>& output_layouts);
};
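
// A DataparallelSPMDExpander is normally instantiated when the op is
// registered in the corresponding .cc file. A minimal sketch, assuming the
// REGISTER_SPMD macro from spmd_expander.h forwards its trailing arguments to
// the expander's constructor (the op and ranks below are illustrative, not a
// statement of what is actually registered):
//
//   REGISTER_SPMD(FFT2D, TF::FFT2DOp, DataparallelSPMDExpander,
//                 llvm::DenseMap<int, int>{{0, 2}},
//                 llvm::DenseMap<int, int>{{0, 2}});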

}  // namespace dtensor
}  // namespace tensorflow

#endif  // TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_DATAPARALLEL_SPMD_EXPANDER_H_