xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/expansions/cumsum_spmd_expander.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/dtensor/mlir/expansions/cumsum_spmd_expander.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <string>
21 #include <utility>
22 
23 #include "llvm/Support/FormatVariadic.h"
24 #include "tensorflow/core/platform/errors.h"
25 #include "tensorflow/dtensor/cc/dstatus.h"
26 #include "tensorflow/dtensor/mlir/collectives.h"
27 #include "tensorflow/dtensor/mlir/layout_parsing.h"
28 #include "tensorflow/dtensor/mlir/op_utils.h"
29 #include "tensorflow/dtensor/mlir/shape_utils.h"
30 #include "tensorflow/dtensor/mlir/value_utils.h"
31 
32 namespace tensorflow {
33 namespace dtensor {
34 
35 namespace {
36 
37 // Extract `axis` tensor from Cumsum op and return its positive value, since
38 // it can be a negative index.
GetAxisDimension(mlir::Operation * op)39 StatusOr<int64_t> GetAxisDimension(mlir::Operation* op) {
40   auto cumsum = llvm::dyn_cast<mlir::TF::CumsumOp>(op);
41   if (cumsum == nullptr) {
42     return errors::Internal(
43         absl::StrCat("Expected Cumsum op but got : ", OpName(op)).c_str());
44   }
45   TF_ASSIGN_OR_RETURN(int64_t axis_dim,
46                       ExtractConstIntFromValue(cumsum.axis()));
47   int64_t tensor_rank = ValueRank(cumsum.x());
48   // Axis can be in range [-tensor_rank, tensor_rank), so we add tensor_rank
49   // to wrap it around.
50   if (axis_dim >= -tensor_rank && axis_dim < 0) {
51     axis_dim += tensor_rank;
52   } else if (axis_dim < -tensor_rank || axis_dim >= tensor_rank) {
53     return errors::InvalidArgument(
54         "Invalid axis; expected a value in [-tensor_rank, tensor_rank)");
55   }
56   return axis_dim;
57 }
58 
59 }  // namespace
60 
ExpandOp(mlir::Operation * op)61 StatusOr<mlir::Operation*> CumsumSPMDExpander::ExpandOp(mlir::Operation* op) {
62   StatusOr<int64_t> axis_dim = GetAxisDimension(op);
63   if (!axis_dim.ok()) return axis_dim.status();
64 
65   TF_ASSIGN_OR_RETURN(auto output_layout, ExtractSingleLayoutFromOp(op));
66   assert(output_layout);
67 
68   // Our intermediate computation layout is the output layout with
69   // the axis dimension replicated. So set both the operand and output layout
70   // to this intermediate layout.
71   Layout intermediate_layout = output_layout->GetLayoutWithReducedDims(
72       {axis_dim.ValueOrDie()}, /*keep_dims=*/true);
73 
74   // Relayout operand to intermediate layout.
75   mlir::OpBuilder builder(op);
76   const auto operand = op->getOperand(0);
77   TF_ASSIGN_OR_RETURN(auto operand_layout, ExtractLayoutFromOperand(operand));
78   if (!operand_layout)
79     return errors::InvalidArgument(
80         "input layout of Cumsum op must be known before SPMD "
81         "expansion.");
82 
83   TF_ASSIGN_OR_RETURN(
84       const auto new_operand,
85       EmitRelayout(operand, operand_layout.value(), intermediate_layout));
86   op->setOperand(0, new_operand);
87 
88   op = InferSPMDExpandedLocalShape(op);
89 
90   // Relayout output to intermediate layout.
91   llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
92   builder.setInsertionPointAfter(op);
93   TF_ASSIGN_OR_RETURN(auto final_output,
94                       EmitRelayout(op->getOpResult(0), intermediate_layout,
95                                    output_layout.value(), &newly_created_ops));
96   op->getOpResult(0).replaceAllUsesExcept(final_output, newly_created_ops);
97   return final_output.getDefiningOp();
98 }
99 
ComputeLayoutForward(mlir::Operation * op,const llvm::DenseMap<int,Layout> & input_layouts)100 StatusOr<llvm::DenseMap<int, Layout>> CumsumSPMDExpander::ComputeLayoutForward(
101     mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
102   TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshEnclosingCluster(op));
103   TF_ASSIGN_OR_RETURN(int64_t axis_dim, GetAxisDimension(op));
104 
105   if (input_layouts.find(0) == input_layouts.end())
106     return llvm::DenseMap<int, Layout>();
107 
108   auto input_layout = input_layouts.lookup(0);
109   return llvm::DenseMap<int, Layout>(
110       {{0, input_layout.GetLayoutWithReducedDims({axis_dim},
111                                                  /*keep_dims=*/true)}});
112 }
113 
ComputeLayoutBackward(mlir::Operation * op,const llvm::DenseMap<int,Layout> & output_layouts)114 StatusOr<llvm::DenseMap<int, Layout>> CumsumSPMDExpander::ComputeLayoutBackward(
115     mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
116   TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshEnclosingCluster(op));
117   TF_ASSIGN_OR_RETURN(int64_t axis_dim, GetAxisDimension(op));
118 
119   if (output_layouts.find(0) == output_layouts.end())
120     return llvm::DenseMap<int, Layout>();
121   auto output_layout = output_layouts.lookup(0);
122   return llvm::DenseMap<int, Layout>(
123       {{0, output_layout.GetLayoutWithReducedDims({axis_dim},
124                                                   /*keep_dims=*/true)}});
125 }
126 
127 }  // namespace dtensor
128 }  // namespace tensorflow
129