1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_DATAPARALLEL_SPMD_EXPANDER_H_ 17 #define TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_DATAPARALLEL_SPMD_EXPANDER_H_ 18 #include <utility> 19 20 #include "llvm/ADT/DenseMap.h" 21 #include "tensorflow/dtensor/mlir/spmd_expander.h" 22 23 namespace tensorflow { 24 namespace dtensor { 25 26 // General SPMD Expander for data parallel ops. 27 28 // We define data parallel ops as ops that have tensors possibly with a batch 29 // dimension. Assumes batch dimensions start from the left. Tensors may 30 // may have multiple batch dimensions, including zero 31 class DataparallelSPMDExpander : public SPMDExpanderBase { 32 protected: 33 // These maps contain {arg_index, non_batch_rank} 34 // Example is for TF:FFT2D, the batchable_operands and batchable_outputs has 35 // {0, 2} because the first argument is batchable and the last 2 dimensions 36 // are the non-batch dimensions 37 llvm::DenseMap<int, int> batchable_operands_; 38 llvm::DenseMap<int, int> batchable_outputs_; 39 StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override; 40 41 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( 42 mlir::Operation* op, 43 const llvm::DenseMap<int, Layout>& input_layouts) override; 44 45 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( 46 mlir::Operation* op, 47 const llvm::DenseMap<int, Layout>& output_layouts) override; 48 49 public: DataparallelSPMDExpander(llvm::DenseMap<int,int> batchable_operands,llvm::DenseMap<int,int> batchable_outputs)50 explicit DataparallelSPMDExpander(llvm::DenseMap<int, int> batchable_operands, 51 llvm::DenseMap<int, int> batchable_outputs) 52 : batchable_operands_(std::move(batchable_operands)), 53 batchable_outputs_(std::move(batchable_outputs)) {} 54 55 private: 56 // Relayouts all operands and outputs with a batch dimensions to a batch 57 // sharded layout. This should only be called when there is at least one 58 // batch sharded operand or batch sharded output 59 StatusOr<mlir::Operation*> RelayoutOperandsAndOutputs( 60 mlir::Operation* op, const std::vector<Layout>& operand_layouts, 61 const std::vector<Layout>& output_layouts); 62 }; 63 } // namespace dtensor 64 } // namespace tensorflow 65 66 #endif // TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_DATAPARALLEL_SPMD_EXPANDER_H_ 67