/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
namespace toco {
namespace {

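// Copies a shape's dimension sizes into the inlined-vector form expected by
// ::tensorflow::MatMulBCast.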
absl::InlinedVector<int64_t, 4> ToInlinedVector(const std::vector<int>& vec) {
  return absl::InlinedVector<int64_t, 4>(vec.begin(), vec.end());
}

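// Reshapes `input` into a rank-3 [batch_size, num_rows, num_cols] tensor, then
// slices out each batch as a rank-2 [num_rows, num_cols] array. Returns the
// names of the per-batch slice outputs, in batch order.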
std::vector<std::string> SliceInput(
    const std::string& input, const std::string& base_name,
    const std::string& input_name, const int batch_size,
    const Array& input_array, Model* model,
    std::vector<std::unique_ptr<Operator>>::iterator* tail_it) {
  int rank = input_array.shape().dimensions_count();
  int num_rows = input_array.shape().dims(rank - 2);
  int num_cols = input_array.shape().dims(rank - 1);
  // Reshape to rank-3 Tensor with first dimension as the batch size.
  auto* reshape_op = new TensorFlowReshapeOperator;
  reshape_op->inputs = {
      input,
      CreateInt32Array(model,
                       absl::StrCat(base_name, "/reshape_", input_name,
                                    "/shape"),
                       {batch_size, num_rows, num_cols})};
  reshape_op->outputs = {AvailableArrayName(
      *model, absl::StrCat(base_name, "/reshape_", input_name, "/reshape"))};
  auto& reshape_op_output = model->GetOrCreateArray(reshape_op->outputs[0]);
  reshape_op_output.data_type = input_array.data_type;
  *tail_it = model->operators.emplace(*tail_it, reshape_op) + 1;

  // Slice along each batch index and remember the slice output for future use.
  std::vector<std::string> slice_outputs;
  slice_outputs.reserve(batch_size);
  for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
    std::string batch_name =
        absl::StrCat(base_name, "_b", batch_idx, "/slice_", input_name);
    auto* slice_op = new SliceOperator;
    slice_op->inputs = {
        reshape_op->outputs[0],
        CreateInt32Array(model, absl::StrCat(batch_name, "/slice/begin"),
                         {batch_idx, 0, 0}),
        CreateInt32Array(model, absl::StrCat(batch_name, "/slice/size"),
                         {1, num_rows, num_cols})};
    slice_op->outputs = {
        AvailableArrayName(*model, absl::StrCat(batch_name, "/slice"))};
    auto& slice_op_output = model->GetOrCreateArray(slice_op->outputs[0]);
    slice_op_output.data_type = input_array.data_type;
    *tail_it = model->operators.emplace(*tail_it, slice_op) + 1;

    // Reshape to rank-2: [1, num_rows, num_cols] -> [num_rows, num_cols].
    auto* slice_reshape_op = new TensorFlowReshapeOperator;
    slice_reshape_op->inputs = {
        slice_op->outputs[0],
        CreateInt32Array(model, absl::StrCat(batch_name, "/reshape/shape"),
                         {num_rows, num_cols})};
    slice_reshape_op->outputs = {
        AvailableArrayName(*model, absl::StrCat(batch_name, "/reshape"))};
    auto& slice_reshape_op_output =
        model->GetOrCreateArray(slice_reshape_op->outputs[0]);
    slice_reshape_op_output.data_type = input_array.data_type;
    *tail_it = model->operators.emplace(*tail_it, slice_reshape_op) + 1;

    slice_outputs.push_back(slice_reshape_op->outputs[0]);
  }
  return slice_outputs;
}

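// Returns the identity permutation with the last two axes swapped, e.g.
// {0, 1, 3, 2} for a rank-4 input.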
std::vector<int32> GetTransposePerm(const Array& input_array) {
  const int32_t dims = input_array.shape().dimensions_count();
  std::vector<int32> perm_array_val(dims);
  for (int32_t i = 0; i < dims; ++i) {
    perm_array_val[i] = i;
  }
  perm_array_val[dims - 2] = dims - 1;
  perm_array_val[dims - 1] = dims - 2;
  return perm_array_val;
}

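// Applies `perm_array_val` to `input_shape` to compute the shape of the
// transposed output.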
std::vector<int32> GetTransposeShape(const Shape& input_shape,
                                     const std::vector<int32>& perm_array_val) {
  const int32_t dims = input_shape.dimensions_count();
  std::vector<int32> output_shape(dims);
  for (int32_t i = 0; i < dims; ++i) {
    output_shape[i] = input_shape.dims(perm_array_val[i]);
  }
  return output_shape;
}

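// Builds (but does not insert) a Transpose operator that swaps the last two
// dimensions of `input`; the caller emplaces it into the operator list.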
TransposeOperator* TransposeInput(const std::string& input, Model* model) {
  const auto& input_array = model->GetArray(input);
  const auto perm_array = GetTransposePerm(input_array);
  const std::string perm_array_name = CreateInt32Array(
      model, AvailableArrayName(*model, input + "/transpose/perm"), perm_array);
  auto* transpose_op = new TransposeOperator;
  transpose_op->inputs = {input, perm_array_name};
  transpose_op->outputs = {AvailableArrayName(*model, input + "/transpose")};
  auto& transpose_array = model->GetOrCreateArray(transpose_op->outputs[0]);
  *transpose_array.mutable_shape()->mutable_dims() =
      GetTransposeShape(input_array.shape(), perm_array);
  return transpose_op;
}

}  // namespace

// Unrolls a BatchMatMul on the batch dimension.
// We need to slice each batch out of the inputs, matmul them individually, then
// stack them all back together at the end.
::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index,
                                            bool* modified) {
  *modified = false;
  auto batch_op_it = model->operators.begin() + op_index;
  if (batch_op_it->get()->type != OperatorType::kBatchMatMul) {
    return ::tensorflow::OkStatus();
  }
  const auto* batch_op =
      static_cast<const BatchMatMulOperator*>(batch_op_it->get());
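  // `tail_it` tracks the insertion point so that every operator emitted below
  // lands just before the original BatchMatMul, preserving topological order.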
  auto& tail_it = batch_op_it;

  std::string input_lhs = batch_op->inputs[0];
  std::string input_rhs = batch_op->inputs[1];
  const auto& input_lhs_array = model->GetArray(input_lhs);
  const auto& input_rhs_array = model->GetArray(input_rhs);
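  // Both input shapes must be known before the op can be unrolled; leave the
  // graph untouched until shape propagation has established them.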
  if (!input_lhs_array.has_shape() || !input_rhs_array.has_shape())
    return ::tensorflow::OkStatus();

  // Transpose LHS input if necessary.
  if (batch_op->adj_x) {
    TransposeOperator* transpose_op = TransposeInput(input_lhs, model);
    tail_it = model->operators.emplace(tail_it, transpose_op) + 1;
    input_lhs = transpose_op->outputs[0];
  }
  const auto& input_array_a = model->GetArray(input_lhs);

  // Transpose RHS input if necessary.
  if (batch_op->adj_y) {
    TransposeOperator* transpose_op = TransposeInput(input_rhs, model);
    tail_it = model->operators.emplace(tail_it, transpose_op) + 1;
    input_rhs = transpose_op->outputs[0];
  }
  const auto& input_array_b = model->GetArray(input_rhs);

  // Ensure that input ranks are at least 2 and batch shapes are broadcastable.
  const int dims_a = input_array_a.shape().dimensions_count();
  const int dims_b = input_array_b.shape().dimensions_count();
  CHECK_GE(dims_a, 2) << "First input must have rank >= 2";
  CHECK_GE(dims_b, 2) << "Second input must have rank >= 2";

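  // MatMulBCast computes the broadcast batch shape of the two inputs and, for
  // each output batch index, which input batch it maps back to.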
  ::tensorflow::MatMulBCast bcast(
      ToInlinedVector(input_array_a.shape().dims()),
      ToInlinedVector(input_array_b.shape().dims()));
  CHECK(bcast.IsValid()) << "Input batch dimensions must be broadcastable";

  CHECK_EQ(input_array_a.shape().dims(dims_a - 1),
           input_array_b.shape().dims(dims_b - 2))
      << "Input dimensions must be compatible for multiplication. shape a = ["
      << absl::StrJoin(input_array_a.shape().dims(), ", ") << "], shape b = ["
      << absl::StrJoin(input_array_b.shape().dims(), ", ") << "]";

  if (dims_a == 2 && dims_b == 2) {
    // This is really just a MatMul.
    AddMessageF("Replacing non-batch BatchMatMul %s by a MatMul operator",
                LogName(*batch_op));
    auto* matmul_op = new TensorFlowMatMulOperator;
    matmul_op->inputs = {input_lhs, input_rhs};
    matmul_op->outputs = batch_op->outputs;
    model->operators.emplace(tail_it, matmul_op);
    DeleteOpAndArrays(model, batch_op);
    *modified = true;
    return ::tensorflow::OkStatus();
  }
  AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op),
              bcast.output_batch_size());
  std::string base_name = std::string(batch_op->outputs[0]);

  // Compute slices for each batch in the LHS and RHS.
  std::vector<std::string> slice_a_outputs =
      SliceInput(input_lhs, base_name, "a", bcast.x_batch_size(), input_array_a,
                 model, &tail_it);
  std::vector<std::string> slice_b_outputs =
      SliceInput(input_rhs, base_name, "b", bcast.y_batch_size(), input_array_b,
                 model, &tail_it);

  // Compute (single batch) MatMul for each output batch. The MatMul outputs are
  // then packed together into one output Tensor.
  std::vector<std::string> pack_inputs;
  for (int64_t batch_idx = 0; batch_idx < bcast.output_batch_size();
       ++batch_idx) {
    std::string batch_name =
        absl::StrCat(batch_op->outputs[0], "_b", batch_idx);
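    // Map the output batch index back to the corresponding input batch; the
    // mapping is the identity when no broadcasting is required.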
    const int a_batch_idx = bcast.IsBroadcastingRequired()
                                ? bcast.x_batch_indices()[batch_idx]
                                : batch_idx;
    const int b_batch_idx = bcast.IsBroadcastingRequired()
                                ? bcast.y_batch_indices()[batch_idx]
                                : batch_idx;
    auto* matmul_op = new TensorFlowMatMulOperator;
    matmul_op->inputs = {slice_a_outputs[a_batch_idx],
                         slice_b_outputs[b_batch_idx]};
    matmul_op->outputs = {AvailableArrayName(*model, batch_name)};
    auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]);
    matmul_op_output.data_type = input_array_a.data_type;
    tail_it = model->operators.emplace(tail_it, matmul_op) + 1;

    // Add to stack.
    pack_inputs.push_back(matmul_op->outputs[0]);
  }

  // Combine the result of each individual MatMul into a rank-3 Tensor.
  auto* pack_op = new PackOperator;
  pack_op->inputs = pack_inputs;
  pack_op->outputs = {AvailableArrayName(*model, base_name + "/pack")};
  auto& pack_op_output = model->GetOrCreateArray(pack_op->outputs[0]);
  pack_op_output.data_type = input_array_a.data_type;
  pack_op->axis = 0;
  pack_op->values_count = pack_inputs.size();
  tail_it = model->operators.emplace(tail_it, pack_op) + 1;

  // Reshape the rank-3 Tensor into the correct output shape.
  const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes();
  std::vector<int> result_shape;
  // Explicitly cast 64-bit sizes to int in order to avoid MSVC warnings.
  std::transform(result_batch_shape.begin(), result_batch_shape.end(),
                 std::back_inserter(result_shape),
                 [](const int64_t dim) { return static_cast<int>(dim); });
  result_shape.push_back(input_array_a.shape().dims(dims_a - 2));
  result_shape.push_back(input_array_b.shape().dims(dims_b - 1));

  auto* reshape_result_op = new TensorFlowReshapeOperator;
  reshape_result_op->inputs = {
      pack_op->outputs[0],
      CreateInt32Array(model, base_name + "/reshape_out/shape", result_shape)};
  reshape_result_op->outputs = {batch_op->outputs[0]};
  model->operators.emplace(tail_it, reshape_result_op);

  DeleteOpAndArrays(model, batch_op);
  *modified = true;
  return ::tensorflow::OkStatus();
}

}  // namespace toco