/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
30
namespace toco {
namespace {

ToInlinedVector(const std::vector<int> & vec)34 absl::InlinedVector<int64_t, 4> ToInlinedVector(const std::vector<int>& vec) {
35 return absl::InlinedVector<int64_t, 4>(vec.begin(), vec.end());
36 }
37
SliceInput(const std::string & input,const std::string & base_name,const std::string & input_name,const int batch_size,const Array & input_array,Model * model,std::vector<std::unique_ptr<Operator>>::iterator * tail_it)38 std::vector<std::string> SliceInput(
39 const std::string& input, const std::string& base_name,
40 const std::string& input_name, const int batch_size,
41 const Array& input_array, Model* model,
42 std::vector<std::unique_ptr<Operator>>::iterator* tail_it) {
43 int rank = input_array.shape().dimensions_count();
44 int num_rows = input_array.shape().dims(rank - 2);
45 int num_cols = input_array.shape().dims(rank - 1);
46 // Reshape to rank-3 Tensor with first dimension as the batch size.
47 auto* reshape_op = new TensorFlowReshapeOperator;
48 reshape_op->inputs = {
49 input,
50 CreateInt32Array(model, absl::StrCat(base_name, "/reshape_a/shape"),
51 {batch_size, num_rows, num_cols})};
52 reshape_op->outputs = {AvailableArrayName(
53 *model, absl::StrCat(base_name, "/reshape_", input_name, "/reshape"))};
54 auto& reshape_op_output = model->GetOrCreateArray(reshape_op->outputs[0]);
55 reshape_op_output.data_type = input_array.data_type;
56 *tail_it = model->operators.emplace(*tail_it, reshape_op) + 1;
57
58 // Slice along each batch index and remember the slice output for future use.
59 std::vector<std::string> slice_outputs;
60 slice_outputs.reserve(batch_size);
61 for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
62 std::string batch_name =
63 absl::StrCat(base_name, "_b", batch_idx, "/slice_", input_name);
64 auto* slice_op = new SliceOperator;
65 slice_op->inputs = {
66 reshape_op->outputs[0],
67 CreateInt32Array(model, absl::StrCat(batch_name, "/slice/begin"),
68 {batch_idx, 0, 0}),
69 CreateInt32Array(model, absl::StrCat(batch_name, "/slice/size"),
70 {1, num_rows, num_cols})};
71 slice_op->outputs = {
72 AvailableArrayName(*model, absl::StrCat(batch_name, "/slice"))};
73 auto& slice_op_output = model->GetOrCreateArray(slice_op->outputs[0]);
74 slice_op_output.data_type = input_array.data_type;
75 *tail_it = model->operators.emplace(*tail_it, slice_op) + 1;
76
77 // Reshape to rank-2: [1, num_rows, num_cols] -> [num_rows, num_cols].
78 auto* slice_reshape_op = new TensorFlowReshapeOperator;
79 slice_reshape_op->inputs = {
80 slice_op->outputs[0],
81 CreateInt32Array(model, absl::StrCat(batch_name, "/reshape/shape"),
82 {num_rows, num_cols})};
83 slice_reshape_op->outputs = {
84 AvailableArrayName(*model, absl::StrCat(batch_name, "/reshape"))};
85 auto& slice_reshape_op_output =
86 model->GetOrCreateArray(slice_reshape_op->outputs[0]);
87 slice_reshape_op_output.data_type = input_array.data_type;
88 *tail_it = model->operators.emplace(*tail_it, slice_reshape_op) + 1;
89
90 slice_outputs.push_back(slice_reshape_op->outputs[0]);
91 }
92 return slice_outputs;
93 }
94
GetTransposePerm(const Array & input_array)95 std::vector<int32> GetTransposePerm(const Array& input_array) {
96 const int32_t dims = input_array.shape().dimensions_count();
97 std::vector<int32> perm_array_val(dims);
98 for (int32_t i = 0; i < dims; ++i) {
99 perm_array_val[i] = i;
100 }
101 perm_array_val[dims - 2] = dims - 1;
102 perm_array_val[dims - 1] = dims - 2;
103 return perm_array_val;
104 }
105
GetTransposeShape(const Shape & input_shape,const std::vector<int32> & perm_array_val)106 std::vector<int32> GetTransposeShape(const Shape& input_shape,
107 const std::vector<int32>& perm_array_val) {
108 const int32_t dims = input_shape.dimensions_count();
109 std::vector<int32> output_shape(dims);
110 for (int32_t i = 0; i < dims; ++i) {
111 output_shape[i] = input_shape.dims(perm_array_val[i]);
112 }
113 return output_shape;
114 }
115
TransposeInput(const std::string & input,Model * model)116 TransposeOperator* TransposeInput(const std::string& input, Model* model) {
117 const auto& input_array = model->GetArray(input);
118 const auto perm_array = GetTransposePerm(input_array);
119 const std::string perm_array_name = CreateInt32Array(
120 model, AvailableArrayName(*model, input + "/transpose/perm"), perm_array);
121 auto* transpose_op = new TransposeOperator;
122 transpose_op->inputs = {input, perm_array_name};
123 transpose_op->outputs = {AvailableArrayName(*model, input + "/transpose")};
124 auto& transpose_array = model->GetOrCreateArray(transpose_op->outputs[0]);
125 *transpose_array.mutable_shape()->mutable_dims() =
126 GetTransposeShape(input_array.shape(), perm_array);
127 model->GetOrCreateArray(transpose_op->outputs[0]);
128 return transpose_op;
129 }
130
}  // namespace

// Unrolls a BatchMatMul on the batch dimension.
// We need to slice each batch out of the inputs, matmul them individually, then
// stack them all back together at the end.
//
// Returns OkStatus in all cases; *modified reports whether the graph changed.
// No-op (modified=false) when the op at op_index is not a BatchMatMul or when
// either input's shape is not yet known (shape propagation runs this
// transformation again later).
::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index,
                                            bool* modified) {
  *modified = false;
  auto batch_op_it = model->operators.begin() + op_index;
  if (batch_op_it->get()->type != OperatorType::kBatchMatMul) {
    return ::tensorflow::OkStatus();
  }
  const auto* batch_op =
      static_cast<const BatchMatMulOperator*>(batch_op_it->get());
  // tail_it is a reference-alias of batch_op_it; every emplace below
  // reassigns it to just past the newly inserted operator, so all emitted ops
  // land in order ahead of the original BatchMatMul. (emplace invalidates
  // vector iterators, which is why the returned iterator is always reused.)
  auto& tail_it = batch_op_it;

  std::string input_lhs = batch_op->inputs[0];
  std::string input_rhs = batch_op->inputs[1];
  const auto& input_lhs_array = model->GetArray(input_lhs);
  const auto& input_rhs_array = model->GetArray(input_rhs);
  // Defer until shape propagation has established both input shapes.
  if (!input_lhs_array.has_shape() || !input_rhs_array.has_shape())
    return ::tensorflow::OkStatus();

  // Transpose LHS input if necessary.
  if (batch_op->adj_x) {
    TransposeOperator* transpose_op = TransposeInput(input_lhs, model);
    tail_it = model->operators.emplace(tail_it, transpose_op) + 1;
    input_lhs = transpose_op->outputs[0];
  }
  // After this point input_array_a refers to the (possibly transposed) LHS.
  const auto& input_array_a = model->GetArray(input_lhs);

  // Transpose RHS input if necessary.
  if (batch_op->adj_y) {
    TransposeOperator* transpose_op = TransposeInput(input_rhs, model);
    tail_it = model->operators.emplace(tail_it, transpose_op) + 1;
    input_rhs = transpose_op->outputs[0];
  }
  // After this point input_array_b refers to the (possibly transposed) RHS.
  const auto& input_array_b = model->GetArray(input_rhs);

  // Ensure that input ranks are at least 2 and batch shapes are broadcastable.
  const int dims_a = input_array_a.shape().dimensions_count();
  const int dims_b = input_array_b.shape().dimensions_count();
  CHECK_GE(dims_a, 2) << "First input must have rank >= 2";
  CHECK_GE(dims_b, 2) << "Second input must have rank >= 2";

  ::tensorflow::MatMulBCast bcast(
      ToInlinedVector(input_array_a.shape().dims()),
      ToInlinedVector(input_array_b.shape().dims()));
  CHECK(bcast.IsValid()) << "Input batch dimensions must be broadcastable";

  // Inner (contraction) dimensions must match: a[..., M, K] x b[..., K, N].
  CHECK_EQ(input_array_a.shape().dims(dims_a - 1),
           input_array_b.shape().dims(dims_b - 2))
      << "Input dimensions must be compatible for multiplication. shape a = ["
      << absl::StrJoin(input_array_a.shape().dims(), ", ") << "], shape b = ["
      << absl::StrJoin(input_array_b.shape().dims(), ", ") << "]";

  if (dims_a == 2 && dims_b == 2) {
    // This is really just a MatMul.
    AddMessageF("Replacing non-batch BatchMatMul %s by a MatMul operator",
                LogName(*batch_op));
    auto* matmul_op = new TensorFlowMatMulOperator;
    matmul_op->inputs = {input_lhs, input_rhs};
    // Reuse the BatchMatMul's output array so consumers are unaffected.
    matmul_op->outputs = batch_op->outputs;
    model->operators.emplace(tail_it, matmul_op);
    DeleteOpAndArrays(model, batch_op);
    *modified = true;
    return ::tensorflow::OkStatus();
  }
  AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op),
              bcast.output_batch_size());
  std::string base_name = std::string(batch_op->outputs[0]);

  // Compute slices for each batch in the LHS and RHS.
  std::vector<std::string> slice_a_outputs =
      SliceInput(input_lhs, base_name, "a", bcast.x_batch_size(), input_array_a,
                 model, &tail_it);
  std::vector<std::string> slice_b_outputs =
      SliceInput(input_rhs, base_name, "b", bcast.y_batch_size(), input_array_b,
                 model, &tail_it);

  // Compute (single batch) MatMul for each output batch. The MatMul outputs are
  // then packed together into one output Tensor.
  std::vector<std::string> pack_inputs;
  for (int64_t batch_idx = 0; batch_idx < bcast.output_batch_size();
       ++batch_idx) {
    std::string batch_name =
        absl::StrCat(batch_op->outputs[0], "_b", batch_idx);
    // When broadcasting, MatMulBCast maps each output batch back to the
    // source batch index on each side; otherwise the indices line up 1:1.
    const int a_batch_idx = bcast.IsBroadcastingRequired()
                                ? bcast.x_batch_indices()[batch_idx]
                                : batch_idx;
    const int b_batch_idx = bcast.IsBroadcastingRequired()
                                ? bcast.y_batch_indices()[batch_idx]
                                : batch_idx;
    auto* matmul_op = new TensorFlowMatMulOperator;
    matmul_op->inputs = {slice_a_outputs[a_batch_idx],
                         slice_b_outputs[b_batch_idx]};
    matmul_op->outputs = {AvailableArrayName(*model, batch_name)};
    auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]);
    matmul_op_output.data_type = input_array_a.data_type;
    tail_it = model->operators.emplace(tail_it, matmul_op) + 1;

    // Add to stack.
    pack_inputs.push_back(matmul_op->outputs[0]);
  }

  // Combine the result of each individual MatMul into a rank-3 Tensor.
  auto* pack_op = new PackOperator;
  pack_op->inputs = pack_inputs;
  pack_op->outputs = {AvailableArrayName(*model, base_name + "/pack")};
  auto& pack_op_output = model->GetOrCreateArray(pack_op->outputs[0]);
  pack_op_output.data_type = input_array_a.data_type;
  pack_op->axis = 0;
  pack_op->values_count = pack_inputs.size();
  tail_it = model->operators.emplace(tail_it, pack_op) + 1;

  // Reshape the rank-3 Tensor into the correct output shape:
  // broadcast batch dims + [M, N].
  const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes();
  std::vector<int> result_shape;
  // Explicitly cast 64-bit sizes to int in order to avoid MSVC warnings.
  std::transform(result_batch_shape.begin(), result_batch_shape.end(),
                 std::back_inserter(result_shape),
                 [](const int64_t dim) { return static_cast<int>(dim); });
  result_shape.push_back(input_array_a.shape().dims(dims_a - 2));
  result_shape.push_back(input_array_b.shape().dims(dims_b - 1));

  auto* reshape_result_op = new TensorFlowReshapeOperator;
  reshape_result_op->inputs = {
      pack_op->outputs[0],
      CreateInt32Array(model, base_name + "/reshape_out/shape", result_shape)};
  // The final reshape writes into the original BatchMatMul output array, so
  // downstream consumers keep working unchanged.
  reshape_result_op->outputs = {batch_op->outputs[0]};
  model->operators.emplace(tail_it, reshape_result_op);

  DeleteOpAndArrays(model, batch_op);
  *modified = true;
  return ::tensorflow::OkStatus();
}

}  // namespace toco