1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/dtensor/mlir/expansions/cumsum_spmd_expander.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <string>
21 #include <utility>
22
23 #include "llvm/Support/FormatVariadic.h"
24 #include "tensorflow/core/platform/errors.h"
25 #include "tensorflow/dtensor/cc/dstatus.h"
26 #include "tensorflow/dtensor/mlir/collectives.h"
27 #include "tensorflow/dtensor/mlir/layout_parsing.h"
28 #include "tensorflow/dtensor/mlir/op_utils.h"
29 #include "tensorflow/dtensor/mlir/shape_utils.h"
30 #include "tensorflow/dtensor/mlir/value_utils.h"
31
32 namespace tensorflow {
33 namespace dtensor {
34
35 namespace {
36
37 // Extract `axis` tensor from Cumsum op and return it's positive value, since
38 // it can be a negative index.
GetAxisDimension(mlir::Operation * op)39 StatusOr<int64_t> GetAxisDimension(mlir::Operation* op) {
40 auto cumsum = llvm::dyn_cast<mlir::TF::CumsumOp>(op);
41 if (cumsum == nullptr) {
42 return errors::Internal(
43 absl::StrCat("Expected Cumsum op but got : ", OpName(op)).c_str());
44 }
45 TF_ASSIGN_OR_RETURN(int64_t axis_dim,
46 ExtractConstIntFromValue(cumsum.axis()));
47 int64_t tensor_rank = ValueRank(cumsum.x());
48 // Axis can be in range [-tensor_rank, tensor_rank), so we add tensor_rank
49 // to wrap it around.
50 if (axis_dim >= -tensor_rank && axis_dim < 0) {
51 axis_dim += tensor_rank;
52 } else if (axis_dim < -tensor_rank || axis_dim >= tensor_rank) {
53 return errors::InvalidArgument(
54 "Invalid axis; expected a value in [-tensor_rank, tensor_rank)");
55 }
56 return axis_dim;
57 }
58
59 } // namespace
60
ExpandOp(mlir::Operation * op)61 StatusOr<mlir::Operation*> CumsumSPMDExpander::ExpandOp(mlir::Operation* op) {
62 StatusOr<int64_t> axis_dim = GetAxisDimension(op);
63 if (!axis_dim.ok()) return axis_dim.status();
64
65 TF_ASSIGN_OR_RETURN(auto output_layout, ExtractSingleLayoutFromOp(op));
66 assert(output_layout);
67
68 // Our intermediate computation layout is the output layout with
69 // the axis dimension replicated. So set both the operand and output layout
70 // to this intermediate layout.
71 Layout intermediate_layout = output_layout->GetLayoutWithReducedDims(
72 {axis_dim.ValueOrDie()}, /*keep_dims=*/true);
73
74 // Relayout operand to intermediate layout.
75 mlir::OpBuilder builder(op);
76 const auto operand = op->getOperand(0);
77 TF_ASSIGN_OR_RETURN(auto operand_layout, ExtractLayoutFromOperand(operand));
78 if (!operand_layout)
79 return errors::InvalidArgument(
80 "input layout of Cumsum op must be known before SPMD "
81 "expansion.");
82
83 TF_ASSIGN_OR_RETURN(
84 const auto new_operand,
85 EmitRelayout(operand, operand_layout.value(), intermediate_layout));
86 op->setOperand(0, new_operand);
87
88 op = InferSPMDExpandedLocalShape(op);
89
90 // Relayout output to intermediate layout.
91 llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
92 builder.setInsertionPointAfter(op);
93 TF_ASSIGN_OR_RETURN(auto final_output,
94 EmitRelayout(op->getOpResult(0), intermediate_layout,
95 output_layout.value(), &newly_created_ops));
96 op->getOpResult(0).replaceAllUsesExcept(final_output, newly_created_ops);
97 return final_output.getDefiningOp();
98 }
99
ComputeLayoutForward(mlir::Operation * op,const llvm::DenseMap<int,Layout> & input_layouts)100 StatusOr<llvm::DenseMap<int, Layout>> CumsumSPMDExpander::ComputeLayoutForward(
101 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
102 TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshEnclosingCluster(op));
103 TF_ASSIGN_OR_RETURN(int64_t axis_dim, GetAxisDimension(op));
104
105 if (input_layouts.find(0) == input_layouts.end())
106 return llvm::DenseMap<int, Layout>();
107
108 auto input_layout = input_layouts.lookup(0);
109 return llvm::DenseMap<int, Layout>(
110 {{0, input_layout.GetLayoutWithReducedDims({axis_dim},
111 /*keep_dims=*/true)}});
112 }
113
ComputeLayoutBackward(mlir::Operation * op,const llvm::DenseMap<int,Layout> & output_layouts)114 StatusOr<llvm::DenseMap<int, Layout>> CumsumSPMDExpander::ComputeLayoutBackward(
115 mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
116 TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshEnclosingCluster(op));
117 TF_ASSIGN_OR_RETURN(int64_t axis_dim, GetAxisDimension(op));
118
119 if (output_layouts.find(0) == output_layouts.end())
120 return llvm::DenseMap<int, Layout>();
121 auto output_layout = output_layouts.lookup(0);
122 return llvm::DenseMap<int, Layout>(
123 {{0, output_layout.GetLayoutWithReducedDims({axis_dim},
124 /*keep_dims=*/true)}});
125 }
126
127 } // namespace dtensor
128 } // namespace tensorflow
129