/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <iterator>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"

namespace toco {

namespace {

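// Computes the output shape and the fixed padding amounts for a conv-like
// windowed operation (Conv, DepthwiseConv, pooling), given the input shape,
// kernel size, strides, dilation factors and padding type.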
void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
                      int kheight, int stride_width, int stride_height,
                      int dilation_width_factor, int dilation_height_factor,
                      PaddingType padding_type, Shape* output_shape,
                      FixedPadding* fixed_padding) {
  const int input_width = input_shape.dims(2);
  const int input_height = input_shape.dims(1);
  const int batch = input_shape.dims(0);

  CHECK_GE(input_width, 1);
  CHECK_GE(input_height, 1);
  CHECK_GE(batch, 1);
  CHECK_GE(kwidth, 1);
  CHECK_GE(kheight, 1);
  CHECK_GE(stride_width, 1);
  CHECK_GE(stride_height, 1);
  CHECK_GE(dilation_width_factor, 1);
  CHECK_GE(dilation_height_factor, 1);

  int dilated_kwidth = dilation_width_factor * (kwidth - 1) + 1;
  int dilated_kheight = dilation_height_factor * (kheight - 1) + 1;

  int output_height = 0;
  int output_width = 0;
  if (padding_type == PaddingType::kValid) {
    output_height =
        (input_height + stride_height - dilated_kheight) / stride_height;
    output_width = (input_width + stride_width - dilated_kwidth) / stride_width;
  } else if (padding_type == PaddingType::kSame) {
    output_height = (input_height + stride_height - 1) / stride_height;
    output_width = (input_width + stride_width - 1) / stride_width;
  } else {
    LOG(FATAL) << "Only supporting SAME or VALID padding";
  }

  fixed_padding->height = std::max(0, ((output_height - 1) * stride_height +
                                       dilated_kheight - input_height) /
                                          2);
  fixed_padding->width = std::max(
      0,
      ((output_width - 1) * stride_width + dilated_kwidth - input_width) / 2);

  // Actually had to debug a situation where those were negative due to bad
  // propagation of placeholder -1 sizes in TensorFlowReshape.
  CHECK_GT(output_width, 0);
  CHECK_GT(output_height, 0);
  output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
}

void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x,
                                     const Shape& input_shape_y,
                                     Array* output_array) {
  // This matches the code in BroadcastBinaryOpShapeFn from tensorflow.
  // It zips together the two input shapes and pads with 1 to make them the
  // same length. For each dimension we broadcast if either dimension is 1 and
  // otherwise expect them to match.
  int rank_x = input_shape_x.dimensions_count();
  int rank_y = input_shape_y.dimensions_count();
  int rank_out = std::max(rank_x, rank_y);
  std::vector<int>* dims_out = output_array->mutable_shape()->mutable_dims();
  dims_out->clear();
  dims_out->reserve(rank_out);
  for (int i = 0; i < rank_out; ++i) {
    int dim_x = i < (rank_out - rank_x)
                    ? 1
                    : input_shape_x.dims(i - (rank_out - rank_x));
    bool dim_y_is_one = i < (rank_out - rank_y);
    int dim_y = dim_y_is_one ? 1 : input_shape_y.dims(i - (rank_out - rank_y));
    if (dim_x == -1 || dim_y == -1) {
      // One or both dimensions are unknown.
      QCHECK(false) << "Shapes must be specified";
    } else if (dim_x == 1 || dim_y == 1) {
      // Broadcast one dimension to the other that is 1.
      if (dim_x == 1 && !dim_y_is_one) {
        // Broadcast dim_y to dim_x (1).
        dims_out->push_back(dim_y);
      } else {
        // Broadcast dim_x to dim_y (1).
        DCHECK_EQ(dim_y, 1);
        dims_out->push_back(dim_x);
      }
    } else {
      // Expect the dimensions to match.
      CHECK_EQ(dim_x, dim_y) << "Dimensions must match";
      dims_out->push_back(dim_x);
    }
  }
  CHECK(output_array->has_shape());
}

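// Sets the output shape of a Conv op from its input and weights shapes, and
// fills in the im2col array shape when a second output is present.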
void ProcessConvOperator(Model* model, ConvOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK(input_shape.dimensions_count() == 4)
      << "Conv ops require 4D inputs. Input array \"" << op->inputs[0]
      << "\" is " << input_shape.dimensions_count() << "D.";

  const auto& weights_array = model->GetArray(op->inputs[1]);
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 4);

  auto& output_array = model->GetArray(op->outputs[0]);
  const int output_depth = weights_shape.dims(0);
  const int kheight = weights_shape.dims(1);
  const int kwidth = weights_shape.dims(2);
  ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
                   op->stride_height, op->dilation_width_factor,
                   op->dilation_height_factor, op->padding.type,
                   output_array.mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
  CHECK_EQ(output_array.shape().dimensions_count(), 4);

  // Set im2col array dimensions if there is one.
  if (op->outputs.size() == 2) {
    const auto& output_shape = output_array.shape();
    const int input_depth = weights_shape.dims(3);
    auto& im2col_array = model->GetArray(op->outputs[1]);
    im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
                                  output_shape.dims(2),
                                  input_depth * kheight * kwidth});
  }
}

void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
  // TransposeConv is unique in that it is specifically given the output shape
  // as a 1D array on its first input. Resolving the output shape is as easy
  // as waiting for this input to be resolved. However, we also have to
  // calculate the padding, which requires the weights shape.

  // SPECIFIED OUTPUT SHAPE
  // The below is the specified, or prescribed, output shape _given_ to the
  // operator as an input.
  auto& specified_output_shape_array =
      model->GetArray(op->inputs[TransposeConvOperator::OUTPUT_SHAPE]);
  if (!specified_output_shape_array.has_shape() ||
      !specified_output_shape_array.buffer) {
    // Yield until the specified output shape is resolved as a constant.
    return;
  }

  CHECK(specified_output_shape_array.data_type == ArrayDataType::kInt32)
      << "TransposeConv output_shape must be int32";

  CHECK(specified_output_shape_array.shape().dimensions_count() == 1 &&
        specified_output_shape_array.shape().dims(0) == 4)
      << "TransposeConv requires a 1D, 4 element array on its 0th input "
         "specifying the output shape. \""
      << op->inputs[TransposeConvOperator::OUTPUT_SHAPE] << "\" had shape "
      << toco::ShapeToString(specified_output_shape_array.shape());

  // COMPUTE PADDING
  // We require the weights shape to calculate padding.
  const auto& weights_array =
      model->GetArray(op->inputs[TransposeConvOperator::WEIGHTS]);
  if (!weights_array.has_shape()) {
    // Yield until weights dims have been resolved.
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 4)
      << "TransposeConv weights must have 4 input dimensions. Input weights \""
      << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
      << toco::ShapeToString(weights_shape) << ".";

  // Compute padding.
  const int kheight = weights_shape.dims(1);
  const int kwidth = weights_shape.dims(2);
  op->padding.GetOrCreateFixedPadding();
  if (op->padding.type == PaddingType::kValid) {
    op->padding.fixed->height = 0;
    op->padding.fixed->width = 0;
  } else if (op->padding.type == PaddingType::kSame) {
    op->padding.fixed->height = (kheight - 1) / 2;
    op->padding.fixed->width = (kwidth - 1) / 2;
  } else {
    LOG(FATAL) << "TransposeConv only supports SAME or VALID padding";
  }

  // VALIDATE some dimensions and set the output shape.
  const auto& input_array =
      model->GetArray(op->inputs[TransposeConvOperator::DATA_INPUT]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4)
      << "TransposeConv input shape must have 4 dimensions. Input \""
      << op->inputs[TransposeConvOperator::DATA_INPUT] << "\" had shape "
      << toco::ShapeToString(input_shape) << ".";
  CHECK_EQ(input_shape.dims(3), weights_shape.dims(3))
      << "Input shape depth and weight depth do not agree";

  // Set the output shape according to the specified output shape.
  std::vector<int32> const& specified_output_shape =
      specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  auto& output_array = model->GetArray(op->outputs[0]);
  *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape;

  // Set im2col array dimensions if there is one.
  if (op->outputs.size() == 2) {
    const int input_depth = weights_shape.dims(3);
    auto& im2col_array = model->GetArray(op->outputs[1]);
    im2col_array.copy_shape(
        Shape{specified_output_shape[0], specified_output_shape[1],
              specified_output_shape[2], input_depth * kheight * kwidth});
  }
}

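// Sets the output shape of a DepthwiseConv op, inferring depth_multiplier
// from the input and weights depths when it has not been set yet.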
void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const auto& weights_array = model->GetArray(op->inputs[1]);
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 4);

  const std::string& output_name = op->outputs[0];
  const int input_depth = input_shape.dims(3);
  const int output_depth = weights_shape.dims(3);
  // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops;
  // instead it has to be inferred from the weights dims. However, once we are
  // here, weights dims have already been converted to our own internal format,
  // where the multiplier is no longer readily apparent. So instead we get it
  // as the quotient of output and input depths. We only want to do that when
  // depth_multiplier had the zero value: any other value should be checked
  // as done by the next if() below.
  if (!op->depth_multiplier) {
    op->depth_multiplier = output_depth / input_depth;
  }
  CHECK_EQ(output_depth, input_depth * op->depth_multiplier)
      << "input/output depths and depth_multiplier don't match";

  const int kheight = weights_shape.dims(1);
  const int kwidth = weights_shape.dims(2);
  ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
                   op->stride_height, op->dilation_width_factor,
                   op->dilation_height_factor, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

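// DepthToSpace rearranges depth into spatial blocks: the output shape is
// {batch, height * block_size, width * block_size, depth / block_size^2}.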
void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const std::string& output_name = op->outputs[0];
  const int block_size = op->block_size;
  CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
  const int batch = input_shape.dims(0);
  const int height = input_shape.dims(1);
  const int width = input_shape.dims(2);
  const int depth = input_shape.dims(3);
  QCHECK_EQ(depth % (block_size * block_size), 0);

  model->GetArray(output_name)
      .copy_shape(Shape({batch, height * block_size, width * block_size,
                         depth / block_size / block_size}));
}

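// SpaceToDepth is the inverse of DepthToSpace: the output shape is
// {batch, height / block_size, width / block_size, depth * block_size^2}.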
void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const std::string& output_name = op->outputs[0];
  const int block_size = op->block_size;
  CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
  const int batch = input_shape.dims(0);
  const int height = input_shape.dims(1);
  const int width = input_shape.dims(2);
  const int depth = input_shape.dims(3);
  QCHECK_EQ(width % block_size, 0);
  QCHECK_EQ(height % block_size, 0);

  model->GetArray(output_name)
      .copy_shape(Shape({batch, height / block_size, width / block_size,
                         depth * block_size * block_size}));
}

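// Handles ops whose first input is a constant 1D tensor giving the output
// shape: copies those dims straight to the output array.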
void ProcessOpWithShapeInput(Model* model, Operator* op) {
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // We have already run.
    return;
  }

  auto& dims_array = model->GetArray(op->inputs[0]);
  if (!dims_array.has_shape()) {
    // Yield until the dims shape has been resolved.
    return;
  }
  if (!dims_array.buffer) {
    // Yield until the dims are constant.
    return;
  }
  CHECK(dims_array.data_type == ArrayDataType::kInt32) << "dims must be int32";
  CHECK_LE(RequiredBufferSizeForShape(dims_array.shape()), 6)
      << "dims vector can be no larger than 6 values";

  std::vector<int32> const& dims =
      dims_array.GetBuffer<ArrayDataType::kInt32>().data;
  *(output_array.mutable_shape()->mutable_dims()) = dims;
}

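// Sets the output shape of a FullyConnected op to
// {matmul_repeats, weights_output_depth}, where matmul_repeats is the input's
// flat size divided by the weights' input depth.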
void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  if (input_shape.dimensions_count() < 1) {
    return;
  }

  const auto& weights_array = model->GetArray(op->inputs[1]);
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();

  const int weights_output_depth = weights_shape.dims(0);
  CHECK_EQ(weights_shape.dimensions_count(), 2);

  const int input_overall_size = RequiredBufferSizeForShape(input_shape);
  const int matmul_repeats = input_overall_size / weights_shape.dims(1);
  CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size);

  auto& output_array = model->GetArray(op->outputs[0]);
  output_array.copy_shape(Shape({matmul_repeats, weights_output_depth}));
}

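// Sets the output shape of a Reshape op from its constant shape input,
// resolving an optional -1 wildcard dimension from the input's flat size.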
void ProcessTensorFlowReshapeOperator(Model* model,
                                      TensorFlowReshapeOperator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // We have already run.
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }
  const auto& input_shape = input_array.shape();

  auto& shape_array = model->GetArray(op->inputs[1]);
  if (!shape_array.has_shape()) {
    // Yield until the target_shape shape has been resolved.
    return;
  }
  if (!shape_array.buffer) {
    // Yield until the target_shape is constant.
    return;
  }
  CHECK(shape_array.data_type == ArrayDataType::kInt32)
      << "Reshape dims must be int32";

  // shape_data is the raw array of ints describing the shape
  // in the TensorFlow node. We intentionally make a copy here, rather than
  // modify wildcards in-place below, because in some graphs, the same shape
  // array with a wildcard may be referenced from multiple Reshape nodes, where
  // the wildcard needs to be resolved to distinct values.
  std::vector<int32> shape_data =
      shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  // The Reshape shape may have a wildcard dim, encoded as -1.
  bool has_wildcard = false;
  int wildcard_index = 0;
  int product_non_wildcard_dims = 1;
  for (size_t i = 0; i < shape_data.size(); i++) {
    if (shape_data[i] == -1) {
      CHECK(!has_wildcard);
      has_wildcard = true;
      wildcard_index = i;
    } else {
      product_non_wildcard_dims *= shape_data[i];
    }
  }

  const int input_flat_size = RequiredBufferSizeForShape(input_shape);
  if (has_wildcard) {
    CHECK_GE(input_flat_size, product_non_wildcard_dims)
        << "Array not large enough to fill the requested dimensions for "
           "Reshape op with output \""
        << op->outputs[0] << "\". Are your input shapes correct?";
    shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
  }

  if (shape_data.size() == 1 && shape_data[0] == 0) {
    // We have reshaped a scalar, so preserve as a scalar.
    shape_data.clear();
  }

  auto& output_shape = *output_array.mutable_shape();
  *output_shape.mutable_dims() = shape_data;
  CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape))
      << "Input cannot be reshaped to requested dimensions for Reshape op with "
         "output \""
      << op->outputs[0] << "\". Are your input shapes correct?";
}

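// Copies the shape of the input at input_index to the single output, for ops
// that preserve their input shape.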
void ProcessSimpleOperator(Model* model, Operator* op, int input_index) {
  const auto& input_array = model->GetArray(op->inputs[input_index]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }

  const std::string& output_name = op->outputs[0];
  auto& output_array = model->GetArray(output_name);
  if (output_array.has_shape()) {
    return;
  }

  output_array.copy_shape(input_array.shape());
}

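// Computes the broadcast output shape of a binary elementwise op from its two
// input shapes.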
void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  const auto& input0_array = model->GetArray(op->inputs[0]);
  const auto& input1_array = model->GetArray(op->inputs[1]);
  // Yield until input dims have been resolved.
  if (!input0_array.has_shape() || !input1_array.has_shape()) {
    return;
  }
  const std::string& output_name = op->outputs[0];
  auto& output_array = model->GetArray(output_name);
  ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
                                  &output_array);
}

void ProcessSelectOperator(Model* model, SelectOperator* op) {
  // Yield until all input dims have been resolved.
  for (const auto& input : op->inputs) {
    const auto& input_array = model->GetArray(input);
    if (!input_array.has_shape()) {
      return;
    }
  }

  // Select's output shape matches that of its second and third inputs.
  const auto& input1_array = model->GetArray(op->inputs[1]);
  auto& output_array = model->GetArray(op->outputs[0]);
  output_array.copy_shape(input1_array.shape());
}

void ProcessAddNOperator(Model* model, Operator* op) {
  // Yield until all input dims have been resolved.
  //
  // TODO(myenik): Since AddN does not support broadcasting, maybe we could
  // actually use this to improve shape propagation by propagating the shape of
  // one input to all other inputs once it is resolved instead of just the
  // output, since all inputs must be the same size and shape for a well-formed
  // graph.
  for (const auto& input : op->inputs) {
    const auto& input_array = model->GetArray(input);
    if (!input_array.has_shape()) {
      return;
    }
  }

  // AddN does not support broadcasting, so all inputs must have the same
  // shape; we just take the first input shape and apply it to the output.
  const auto& input0_array = model->GetArray(op->inputs[0]);
  auto& output_array = model->GetArray(op->outputs[0]);
  output_array.copy_shape(input0_array.shape());
}

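// Returns the keep_dims attribute of a reduction operator; fails for any
// other operator type.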
bool KeepDims(const Operator& op) {
  switch (op.type) {
    case OperatorType::kReduceMin:  //  Reduction Min
      return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
    case OperatorType::kReduceMax:  //  Reduction Max
      return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
    case OperatorType::kSum:
      return static_cast<const TensorFlowSumOperator&>(op).keep_dims;
    case OperatorType::kReduceProd:
      return static_cast<const TensorFlowProdOperator&>(op).keep_dims;
    case OperatorType::kMean:
      return static_cast<const MeanOperator&>(op).keep_dims;
    case OperatorType::kAny:
      return static_cast<const TensorFlowAnyOperator&>(op).keep_dims;
    default:
      LOG(FATAL) << "Not a reduction operator!";
      return false;
  }
}

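// Sets the output shape of a reduction op (Sum, Mean, Min, Max, Prod, Any) by
// dropping or keeping the reduced dimensions according to keep_dims.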
void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
  CHECK_LE(op->inputs.size(), 2);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    return;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const bool keep_dims = KeepDims(*op);
  if (op->inputs.size() == 2) {
    // There is a reduction_indices input.
    const auto& reduction_indices_array = model->GetArray(op->inputs[1]);
    if (!reduction_indices_array.buffer) {
      return;
    }
    CHECK(reduction_indices_array.buffer->type == ArrayDataType::kInt32);

    int input_rank = input_shape.dimensions_count();
    std::set<int32> true_indices;
    const auto& reduction_indices =
        reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
    for (size_t i = 0; i < reduction_indices.size(); ++i) {
      const int32_t reduction_index = reduction_indices[i];
      if (reduction_index < -input_rank || reduction_index >= input_rank) {
        CHECK(false) << "Invalid reduction dimension " << reduction_index
                     << " for input with " << input_rank << " dimensions";
      }
      int32_t wrapped_index = reduction_index;
      if (wrapped_index < 0) {
        wrapped_index += input_rank;
      }
      true_indices.insert(wrapped_index);
    }

    auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
    mutable_dims->clear();
    for (int i = 0; i < input_rank; ++i) {
      if (true_indices.count(i) > 0) {
        if (keep_dims) {
          mutable_dims->emplace_back(1);
        }
      } else {
        mutable_dims->emplace_back(input_shape.dims(i));
      }
    }
  } else {
    // No reduction_indices means complete reduction to a single scalar.
    if (keep_dims) {
      output_array.copy_shape(input_shape);
    } else {
      output_array.copy_shape(Shape({}));
    }
  }
}

void ProcessSliceOperator(Model* model, SliceOperator* op) {
  CHECK_EQ(op->inputs.size(), 3);
  CHECK_EQ(op->outputs.size(), 1);

  // Yield until the Slice params have been resolved.
  if (op->begin.empty()) return;

  // Yield until input dims have been resolved.
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) return;
  const Shape& input_shape = input_array.shape();

  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) return;

  CHECK_EQ(input_shape.dims().size(), op->size.size());
  CHECK_EQ(op->begin.size(), op->size.size());

  std::vector<int> output_dims;
  for (size_t i = 0; i < op->begin.size(); ++i) {
    int size = op->size[i];
    if (size == -1) {
      size = input_array.shape().dims(i) - op->begin[i];
    }
    output_dims.push_back(size);
  }

  *output_array.mutable_shape()->mutable_dims() = output_dims;
}

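// Sets the output shape of a ReorderAxes op by shuffling the input dims from
// the input axes order to the output axes order.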
void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
  const std::string& input_name = op->inputs[0];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const std::string& output_name = op->outputs[0];
  Shape* output_shape = model->GetArray(output_name).mutable_shape();
  ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
              output_shape);
}

void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
  // Yield until input dims have been resolved.
  for (const auto& input_name : op->inputs) {
    auto& input_array = model->GetArray(input_name);
    if (!input_array.has_shape()) {
      return;
    }
  }
  auto& output_array = model->GetArray(op->outputs[0]);
  // Use first non-empty input as basis for output dimensions.
  for (const auto& input_name : op->inputs) {
    const auto& input_array = model->GetArray(input_name);
    if (input_array.shape().dimensions_count() > 0) {
      output_array.copy_shape(input_array.shape());
      // Negative axis means the count starts at the back of the dims().
      if (op->axis < 0) op->axis += input_array.shape().dims().size();
      break;
    }
  }
  // Determine the concat size, and enforce that all inputs have
  // the same dimensions count.
  int concat_size = 0;
  for (const auto& input_name : op->inputs) {
    auto& input_array = model->GetArray(input_name);
    CHECK(input_array.has_shape());
    if (input_array.shape().dimensions_count() == 0) {
      continue;
    }
    CHECK_EQ(input_array.shape().dimensions_count(),
             output_array.shape().dimensions_count());
    const std::vector<int>& input_dims = input_array.shape().dims();
    CHECK_LT(op->axis, input_dims.size());
    concat_size += input_dims[op->axis];
  }
  // Write out the concat_size on the output array shape.
  auto& output_shape = *output_array.mutable_shape();
  auto& output_dims = *output_shape.mutable_dims();
  CHECK_LT(op->axis, output_shape.dimensions_count());
  output_dims[op->axis] = concat_size;
}

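// Sets the 1D output shape of a Range op from its constant scalar start,
// limit and delta inputs.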
void ProcessRangeOperator(Model* model, RangeOperator* op) {
  CHECK_EQ(op->inputs.size(), 3);
  const auto& start_array = model->GetArray(op->inputs[0]);
  if (!start_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }
  const auto& limit_array = model->GetArray(op->inputs[1]);
  if (!limit_array.has_shape()) {
    return;
  }
  const auto& delta_array = model->GetArray(op->inputs[2]);
  if (!delta_array.has_shape()) {
    return;
  }

  if (!IsConstantParameterArray(*model, op->inputs[0])) {
    // Yield until inputs are constant.
    return;
  }
  if (!IsConstantParameterArray(*model, op->inputs[1])) {
    return;
  }
  if (!IsConstantParameterArray(*model, op->inputs[2])) {
    return;
  }

  const ArrayDataType& start_dtype = start_array.data_type;
  CHECK(start_dtype == ArrayDataType::kInt32 ||
        start_dtype == ArrayDataType::kFloat)
      << "Range op inputs must be int32 or float.";
  CHECK(limit_array.data_type == start_dtype)
      << "In Range op, limit tensor must have the same data type as start "
         "tensor.";
  CHECK(delta_array.data_type == start_dtype)
      << "In Range op, delta tensor must have the same data type as start "
         "tensor.";
  CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
      << "Range op inputs must be scalar.";
  CHECK_EQ(RequiredBufferSizeForShape(limit_array.shape()), 1)
      << "Range op inputs must be scalar.";
  CHECK_EQ(RequiredBufferSizeForShape(delta_array.shape()), 1)
      << "Range op inputs must be scalar.";

  int size = 0;
  if (start_dtype == ArrayDataType::kInt32) {
    size = std::floor((limit_array.GetBuffer<ArrayDataType::kInt32>().data[0] -
                       start_array.GetBuffer<ArrayDataType::kInt32>().data[0]) /
                      delta_array.GetBuffer<ArrayDataType::kInt32>().data[0]);
  } else if (start_dtype == ArrayDataType::kFloat) {
    size = std::floor((limit_array.GetBuffer<ArrayDataType::kFloat>().data[0] -
                       start_array.GetBuffer<ArrayDataType::kFloat>().data[0]) /
                      delta_array.GetBuffer<ArrayDataType::kFloat>().data[0]);
  }

  // Only set the output shape. Contents are set by ResolveConstantRange.
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  Shape* output_shape = output_array.mutable_shape();
  output_shape->ReplaceDims({size});
}

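// Sets each output shape of a Split op by dividing the input dimension at the
// (constant) axis evenly across num_split outputs.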
void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  const std::string& input_name = op->inputs[1];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const Shape& input_shape = input_array.shape();

  // Yield until axis is constant.
  if (!IsConstantParameterArray(*model, op->inputs[0])) {
    return;
  }

  const auto& axis_array = model->GetArray(op->inputs[0]);

  // Yield until axis dims have been resolved.
  if (!axis_array.has_shape()) {
    return;
  }

  CHECK(axis_array.data_type == ArrayDataType::kInt32)
      << "Axis array must be int32.";
  CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
      << "Axis array must be scalar.";

  int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
  if (axis < 0) {
    axis += input_shape.dimensions_count();
  }

  const int split_dim = input_shape.dims(axis);
  CHECK_EQ(split_dim % op->num_split, 0);
  const int split_depth = split_dim / op->num_split;

  Shape output_shape = input_shape;
  (*output_shape.mutable_dims())[axis] = split_depth;

  CHECK_EQ(op->outputs.size(), op->num_split);
  for (const auto& output : op->outputs) {
    model->GetArray(output).copy_shape(output_shape);
  }
}

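// Sets each output shape of a SplitV op from its constant size_splits and
// axis inputs, resolving at most one -1 entry in size_splits.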
void ProcessTensorFlowSplitVOperator(Model* model,
                                     TensorFlowSplitVOperator* op) {
  CHECK_EQ(op->inputs.size(), 3);

  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const Shape& input_shape = input_array.shape();

  // Yield until size_splits is constant.
  if (!IsConstantParameterArray(*model, op->inputs[1])) {
    return;
  }
  const auto& size_array = model->GetArray(op->inputs[1]);
  // Yield until size_splits dims have been resolved.
  if (!size_array.has_shape()) {
    return;
  }
  const Shape& size_shape = size_array.shape();

  CHECK(size_array.data_type == ArrayDataType::kInt32 ||
        size_array.data_type == ArrayDataType::kInt64)
      << "size_splits must be int32 or int64";
  CHECK_EQ(size_shape.dimensions_count(), 1) << "size_splits must be 1-D";

  std::vector<int64_t> size_splits_vector;
  if (size_array.data_type == ArrayDataType::kInt32) {
    for (const auto each_size :
         size_array.GetBuffer<ArrayDataType::kInt32>().data) {
      size_splits_vector.push_back(each_size);
    }
  } else {
    size_splits_vector = size_array.GetBuffer<ArrayDataType::kInt64>().data;
  }

  // Yield until axis is constant.
  if (!IsConstantParameterArray(*model, op->inputs[2])) {
    return;
  }
  const auto& axis_array = model->GetArray(op->inputs[2]);
  // Yield until axis dims have been resolved.
  if (!axis_array.has_shape()) {
    return;
  }

  CHECK(axis_array.data_type == ArrayDataType::kInt32)
      << "Axis array must be int32.";
  CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
      << "Axis array must be scalar.";

  int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
  if (axis < 0) {
    axis += input_shape.dimensions_count();
  }

  CHECK_EQ(op->num_split, size_splits_vector.size());

  int64_t minus_one_count = 0, size_splits_sum = 0;
  for (auto size : size_splits_vector) {
    if (size == -1) {
      ++minus_one_count;
    } else {
      size_splits_sum += size;
    }
  }

  const int input_size = input_shape.dims(axis);

  CHECK_LE(minus_one_count, 1) << "size_splits can contain at most one -1.";

  if (minus_one_count == 1) {
    CHECK_LE(size_splits_sum, input_size);
    auto iter =
        std::find(size_splits_vector.begin(), size_splits_vector.end(), -1);
    *iter = input_size - size_splits_sum;
  } else {
    CHECK_EQ(size_splits_sum, input_size);
  }

  CHECK_EQ(op->outputs.size(), op->num_split);

  for (size_t i = 0; i < op->outputs.size(); ++i) {
    const auto& output = op->outputs[i];
    Shape output_shape = input_shape;
    (*output_shape.mutable_dims())[axis] = size_splits_vector.at(i);
    model->GetArray(output).copy_shape(output_shape);
  }
}

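// The pooling ops below reuse ComputeConvSizes with dilation factors of 1 to
// derive their output shape and fixed padding from the kernel size, strides
// and padding type.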
void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
  const std::string& input_name = op->inputs[0];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);
  const std::string& output_name = op->outputs[0];
  const int output_depth = input_shape.dims(3);
  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
                   op->stride_width, op->stride_height, 1, 1, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
  const std::string& input_name = op->inputs[0];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);
  const std::string& output_name = op->outputs[0];
  const int output_depth = input_shape.dims(3);
  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
                   op->stride_width, op->stride_height, 1, 1, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
  const std::string& input_name = op->inputs[0];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  if (input_shape.dimensions_count() < 4) {
    LOG(FATAL) << "missing dimensions for " << input_name;
  }
  const std::string& output_name = op->outputs[0];
  const int output_depth = input_shape.dims(3);
  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
                   op->stride_width, op->stride_height, 1, 1, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

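// Sets the output shape of a ResizeBilinear (and, below, ResizeNearestNeighbor)
// op to {batch, new_height, new_width, depth}, where the new spatial size
// comes from the constant size input.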
void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  CHECK_EQ(op->outputs.size(), 1);

  if (!model->GetArray(op->inputs[0]).has_shape() ||
      !model->GetArray(op->inputs[1]).has_shape()) {
    return;
  }
  const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();

  const std::string& output_size_name = op->inputs[1];
  const auto& output_size_array = model->GetArray(output_size_name);
  CHECK(output_size_array.data_type == ArrayDataType::kInt32);
  CHECK(output_size_array.has_shape());
  const auto& output_size_shape = output_size_array.shape();
  CHECK_EQ(output_size_shape.dimensions_count(), 1);
  CHECK_EQ(output_size_shape.dims(0), 2);
  if (!output_size_array.buffer) {
    return;
  }
  std::vector<int32> output_shape =
      output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
  model->GetArray(op->outputs[0])
      .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
                         output_shape[1], input_data_shape.dims(3)}));
}

void ProcessResizeNearestNeighborOperator(Model* model,
                                          ResizeNearestNeighborOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  CHECK_EQ(op->outputs.size(), 1);

  if (!model->GetArray(op->inputs[0]).has_shape() ||
      !model->GetArray(op->inputs[1]).has_shape()) {
    return;
  }
  const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();

  const std::string& output_size_name = op->inputs[1];
  const auto& output_size_array = model->GetArray(output_size_name);
  CHECK(output_size_array.data_type == ArrayDataType::kInt32);
  CHECK(output_size_array.has_shape());
  const auto& output_size_shape = output_size_array.shape();
  CHECK_EQ(output_size_shape.dimensions_count(), 1);
  CHECK_EQ(output_size_shape.dims(0), 2);
  if (!output_size_array.buffer) {
    return;
  }
  std::vector<int32> output_shape =
      output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
  model->GetArray(op->outputs[0])
      .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
                         output_shape[1], input_data_shape.dims(3)}));
}

void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
  // Only required for compact LstmCell with the default NUM_INPUTS inputs.
  if (op->inputs.size() != LstmCellOperator::NUM_INPUTS) return;

  const auto& input_array =
      model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
  // Yield until all input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_GE(input_shape.dimensions_count(), 2);

  const auto& prev_activ_array =
      model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
  // Yield until all input dims have been resolved.
  if (!prev_activ_array.has_shape()) {
    return;
  }
  const auto& prev_activ_shape = prev_activ_array.shape();
  CHECK_GE(prev_activ_shape.dimensions_count(), 2);

  const auto& weights_array =
      model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 2);

  const auto& bias_array =
      model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]);
  // Yield until bias dims have been resolved.
  if (!bias_array.has_shape()) {
    return;
  }
  const auto& bias_shape = bias_array.shape();
  CHECK_GE(bias_shape.dimensions_count(), 1);

  const auto& prev_state_array =
      model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]);
  // Yield until all input dims have been resolved.
  if (!prev_state_array.has_shape()) {
    return;
  }
  const auto& prev_state_shape = prev_state_array.shape();
  CHECK_GE(prev_state_shape.dimensions_count(), 2);

  const int fc_output_depth = weights_shape.dims(0);
  CHECK_EQ(fc_output_depth, bias_shape.dims(0));
  CHECK_EQ(fc_output_depth % 4, 0);
  const int depth = fc_output_depth / 4;

  const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
  const int fc_input_depth = weights_shape.dims(1);
  CHECK_EQ(input_depth + depth, fc_input_depth);
  Shape output_shape(input_shape);
  (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;

  // Set output dimensions.
  model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
      .copy_shape(output_shape);
  model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
      .copy_shape(output_shape);

  Shape concat_temp_shape(input_shape);
  (*concat_temp_shape
        .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
      fc_input_depth;
  model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
      .copy_shape(concat_temp_shape);

  Shape activ_temp_shape(input_shape);
  (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
      fc_output_depth;
  model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
      .copy_shape(activ_temp_shape);
}

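// Sets the output shape of a time-major UnidirectionalSequenceLstm op to
// {timestamp, batch_size, output_size}, where output_size comes from the
// recurrent-to-output weights.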
void ProcessUnidirectionalSequenceLstmOperator(
    Model* model, UnidirectionalSequenceLstmOperator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // Shape already propagated.
    return;
  }

  if (output_array.data_type == ArrayDataType::kNone) {
    // Yield until the output type has been set by PropagateArrayDataTypes.
    return;
  }

  // TODO(renjieliu): check the inputs, as well as all kinds of weights.
  const auto& input_array = model->GetArray(op->inputs[0]);

  constexpr int kInputActivationStateTensor = 18;
  constexpr int kInputCellStateTensor = 19;

  // The TFLite interpreter does not support an array that is both variable and
  // backed by a buffer (see b/115961645 for more discussion).
  // The following block removes the buffer from the array to work around that
  // restriction; as a consequence, downstream applications should not read the
  // lstm state as input to other operations.
  model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset();
  model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset();

  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const int batch_size = input_shape.dims(1);
  const int timestamp = input_shape.dims(0);

  const auto& recurrent_to_output_weights_array =
      model->GetArray(op->inputs[8]);
  // Yield until input dims have been resolved.
  if (!recurrent_to_output_weights_array.has_shape()) {
    return;
  }

  const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
  const int output_size = output_weights_shape.dims(1);

  Shape* output_shape = output_array.mutable_shape();
  output_shape->ReplaceDims({timestamp, batch_size, output_size});
}

void ProcessUnidirectionalSequenceRnnOperator(
    Model* model, UnidirectionalSequenceRnnOperator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // Shape already propagated.
    return;
  }

  if (output_array.data_type == ArrayDataType::kNone) {
    // Yield until the output type has been set by PropagateArrayDataTypes.
    return;
  }

  constexpr int kHiddenStateTensor = 4;
  // The TFLite interpreter does not support an array that is both variable and
  // backed by a buffer (see b/115961645 for more discussion).
  // The following block removes the buffer from the array to work around that
  // restriction; as a consequence, downstream applications should not read the
  // rnn state as input to other operations.
  model->GetArray(op->inputs[kHiddenStateTensor]).buffer.reset();

  // TODO(renjieliu): check the inputs, as well as all kinds of weights.
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const int batch_size = input_shape.dims(1);
  const int timestamp = input_shape.dims(0);

  const auto& bias_array = model->GetArray(op->inputs[3]);
  // Yield until input dims have been resolved.
  if (!bias_array.has_shape()) {
    return;
  }

  const auto& bias_shape = bias_array.shape();
  const int output_size = bias_shape.dims(0);

  Shape* output_shape = output_array.mutable_shape();
  output_shape->ReplaceDims({timestamp, batch_size, output_size});
}

void ProcessBidirectionalSequenceLstmOperator(
    Model* model, BidirectionalSequenceLstmOperator* op) {
  // We assume time major.
  auto& fw_output_array = model->GetArray(op->outputs[0]);
  auto& bw_output_array = model->GetArray(op->outputs[1]);
  if (fw_output_array.has_shape()) {
    // Shape already propagated.
    return;
  }

  if (fw_output_array.data_type == ArrayDataType::kNone) {
    // Yield until the output type has been set by PropagateArrayDataTypes.
    return;
  }

  // TODO(renjieliu): check the inputs, as well as all kinds of weights.
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const int batch_size = input_shape.dims(1);
  const int timestamp = input_shape.dims(0);

  constexpr int kBwRecurrentToOutputWeightsTensor = 25;
  const auto& recurrent_to_output_weights_array =
      model->GetArray(op->inputs[kBwRecurrentToOutputWeightsTensor]);
  // Yield until input dims have been resolved.
  if (!recurrent_to_output_weights_array.has_shape()) {
    return;
  }

  constexpr int kFwInputActivationStateTensor = 35;
  constexpr int kFwInputCellStateTensor = 36;
  constexpr int kBwInputActivationStateTensor = 37;
  constexpr int kBwInputCellStateTensor = 38;
  // b/115961645: This is a hack to work around the same restriction as above.
  model->GetArray(op->inputs[kFwInputActivationStateTensor]).buffer.reset();
  model->GetArray(op->inputs[kFwInputCellStateTensor]).buffer.reset();
  model->GetArray(op->inputs[kBwInputActivationStateTensor]).buffer.reset();
  model->GetArray(op->inputs[kBwInputCellStateTensor]).buffer.reset();

  const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
  const int output_size = output_weights_shape.dims(1);

  Shape* fw_output_shape = fw_output_array.mutable_shape();
  if (op->merge_outputs) {
    fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size});
  } else {
    fw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
    Shape* bw_output_shape = bw_output_array.mutable_shape();
    bw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
  }
}

void ProcessBidirectionalSequenceRnnOperator(
    Model* model, BidirectionalSequenceRnnOperator* op) {
  // We assume time major.
  auto& fw_output_array = model->GetArray(op->outputs[0]);
  auto& bw_output_array = model->GetArray(op->outputs[1]);
  if (fw_output_array.has_shape()) {
    // Shape already propagated.
    return;
  }

  if (fw_output_array.data_type == ArrayDataType::kNone) {
    // Yield until the output type has been set by PropagateArrayDataTypes.
    return;
  }

  // TODO(renjieliu): check the inputs, as well as all kinds of weights.
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const int batch_size = input_shape.dims(1);
  const int timestamp = input_shape.dims(0);

  constexpr int kFwWeightsTensor = 1;
  const auto& forward_weights_array =
      model->GetArray(op->inputs[kFwWeightsTensor]);
  // Yield until input dims have been resolved.
  if (!forward_weights_array.has_shape()) {
    return;
  }

  constexpr int kFwHiddenStateTensor = 4;
  constexpr int kBwHiddenStateTensor = 8;
  // b/115961645: This is a hack to work around the same restriction as above.
  model->GetArray(op->inputs[kFwHiddenStateTensor]).buffer.reset();
  model->GetArray(op->inputs[kBwHiddenStateTensor]).buffer.reset();

  const auto& output_weights_shape = forward_weights_array.shape();
  const int output_size = output_weights_shape.dims(0);

  Shape* fw_output_shape = fw_output_array.mutable_shape();
  if (op->merge_outputs) {
    fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size});
  } else {
    fw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
    Shape* bw_output_shape = bw_output_array.mutable_shape();
    bw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
  }
}

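// Sets the output shape of a SpaceToBatchND op from its constant block_shape
// and paddings inputs: spatial dims are padded and divided by the block
// shape, and the batch dimension is multiplied by the block sizes.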
ProcessSpaceToBatchNDOperator(Model * model,SpaceToBatchNDOperator * op)1283 void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
1284   const auto& input_array = model->GetArray(op->inputs[0]);
1285   // Yield until input dims have been resolved.
1286   if (!input_array.has_shape()) {
1287     return;
1288   }
1289   const auto& input_shape = input_array.shape();
1290   // This method only handles input dimensions of 3 or 4.
1291   if (input_shape.dimensions_count() != 3 &&
1292       input_shape.dimensions_count() != 4) {
1293     return;
1294   }
1295 
1296   const auto& block_shape_array = model->GetArray(op->inputs[1]);
1297   const auto& paddings_array = model->GetArray(op->inputs[2]);
1298   const auto& block_shape_array_shape = block_shape_array.shape();
1299   const auto& paddings_array_shape = paddings_array.shape();
1300   QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
1301   QCHECK_EQ(paddings_array_shape.dimensions_count(), 2);
1302 
1303   int spatial_dims_num = input_shape.dimensions_count() - 2;
1304   QCHECK_EQ(block_shape_array_shape.dims(0), spatial_dims_num);
1305   if (!block_shape_array.buffer) {
1306     return;
1307   }
1308   QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
1309   const auto& block_shape_data =
1310       block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1311 
1312   QCHECK_EQ(paddings_array_shape.dims(0), spatial_dims_num);
1313   QCHECK_EQ(paddings_array_shape.dims(1), 2);  // Two parameters per dimension.
1314   if (!paddings_array.buffer) {
1315     return;
1316   }
1317   QCHECK(paddings_array.data_type == ArrayDataType::kInt32);
1318   const auto& paddings_data =
1319       paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
1320 
1321   Shape output_shape(input_shape);
1322   std::vector<int>* output_shape_data = output_shape.mutable_dims();
1323   int output_batch_size = input_shape.dims(0);
1324   for (int dim = 0; dim < spatial_dims_num; ++dim) {
1325     int final_dim_size = (input_shape.dims(dim + 1) + paddings_data[dim * 2] +
1326                           paddings_data[dim * 2 + 1]);
1327     QCHECK_EQ(final_dim_size % block_shape_data[dim], 0);
1328     output_shape_data->at(dim + 1) = final_dim_size / block_shape_data[dim];
1329     output_batch_size *= block_shape_data[dim];
1330   }
1331 
1332   output_shape_data->at(0) = output_batch_size;
1333   output_shape_data->at(input_shape.dimensions_count() - 1) =
1334       input_shape.dims(input_shape.dimensions_count() - 1);
1335 
1336   model->GetArray(op->outputs[0]).copy_shape(output_shape);
1337 }
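// Worked example (hypothetical shapes, for illustration only): with an input of
// shape [1, 4, 6, 3], block_shape = [2, 3] and zero paddings, the loop above
// yields spatial dims 4 / 2 = 2 and 6 / 3 = 2 and a batch of 1 * 2 * 3 = 6, so
// the propagated output shape is [6, 2, 2, 3].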
1338 
1339 void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
1340   const auto& input_array = model->GetArray(op->inputs[0]);
1341   // Yield until input dims have been resolved.
1342   if (!input_array.has_shape()) {
1343     return;
1344   }
1345   const auto& input_shape = input_array.shape();
1346   CHECK_GE(input_shape.dimensions_count(), 3);
1347   CHECK_LE(input_shape.dimensions_count(), 4);
1348   int spatial_dims_num = input_shape.dimensions_count() - 2;
1349 
1350   const auto& block_shape_array = model->GetArray(op->inputs[1]);
1351   const auto& crops_array = model->GetArray(op->inputs[2]);
1352   const auto& block_shape_array_shape = block_shape_array.shape();
1353   const auto& crops_array_shape = crops_array.shape();
1354   QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
1355   QCHECK_EQ(crops_array_shape.dimensions_count(), 2);
1356 
1357   // The block shape must have one entry per spatial dimension.
1358   QCHECK_EQ(block_shape_array_shape.dims(0), spatial_dims_num);
1359   if (!block_shape_array.buffer) {
1360     return;
1361   }
1362   QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
1363   const auto& block_shape_data =
1364       block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1365 
1366   QCHECK_EQ(crops_array_shape.dims(0), spatial_dims_num);
1367   QCHECK_EQ(crops_array_shape.dims(1), 2);  // Two parameters per dimension.
1368   if (!crops_array.buffer) {
1369     return;
1370   }
1371   QCHECK(crops_array.data_type == ArrayDataType::kInt32);
1372   const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
1373 
1374   Shape output_shape(input_shape);
1375   std::vector<int>* output_shape_data = output_shape.mutable_dims();
1376   int output_batch_size = input_shape.dims(0);
1377   for (int dim = 0; dim < spatial_dims_num; ++dim) {
1378     // The batch size must be a multiple of block_shape[dim].
1379     QCHECK_EQ(output_batch_size % block_shape_data[dim], 0);
1380     output_batch_size = output_batch_size / block_shape_data[dim];
1381     output_shape_data->at(dim + 1) =
1382         input_shape.dims(dim + 1) * block_shape_data[dim] -
1383         crops_data[dim * 2] - crops_data[dim * 2 + 1];
1384   }
1385   output_shape_data->at(0) = output_batch_size;
1386   output_shape_data->at(input_shape.dimensions_count() - 1) =
1387       input_shape.dims(input_shape.dimensions_count() - 1);
1388 
1389   model->GetArray(op->outputs[0]).copy_shape(output_shape);
1390 }
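// Worked example (hypothetical shapes, for illustration only): with an input of
// shape [8, 2, 2, 1], block_shape = [2, 2] and zero crops, the batch becomes
// 8 / (2 * 2) = 2 and each spatial dim becomes 2 * 2 - 0 - 0 = 4, so the
// propagated output shape is [2, 4, 4, 1].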
1391 
1392 void ProcessGatherOperator(Model* model, GatherOperator* op) {
1393   const auto& input_array = model->GetArray(op->inputs[0]);
1394   const auto& indices_array = model->GetArray(op->inputs[1]);
1395   auto& output_array = model->GetArray(op->outputs[0]);
1396 
1397   // Bail if we already know the output shape.
1398   if (output_array.has_shape()) {
1399     return;
1400   }
1401 
1402   // Yield until input dims have been resolved.
1403   if (!input_array.has_shape() || !indices_array.has_shape()) {
1404     return;
1405   }
1406 
1407   // Yield until the axis has been resolved.
1408   if (!op->axis) {
1409     return;
1410   }
1411   int axis = op->axis.value();
1412 
1413   const auto& input_shape = input_array.shape();
1414   const auto& indices_shape = indices_array.shape();
1415   QCHECK_GE(input_shape.dimensions_count(), 1);
1416   op->input_rank = input_shape.dimensions_count();
1417   QCHECK_LT(axis, op->input_rank);
1418 
1419   // Copy the input dimensions to the output, except for the axis dimension,
1420   // where the dimensions of indices_shape are used instead.
1421   auto output_dims = output_array.mutable_shape()->mutable_dims();
1422   for (int dim = 0; dim < axis; ++dim) {
1423     output_dims->push_back(input_shape.dims(dim));
1424   }
1425   for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) {
1426     output_dims->push_back(indices_shape.dims(dim));
1427   }
1428   for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) {
1429     output_dims->push_back(input_shape.dims(dim));
1430   }
1431 }
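// Worked example (hypothetical shapes, for illustration only): gathering from
// an input of shape [4, 5, 6] with indices of shape [2, 3] along axis 1
// replaces the axis dimension with the indices dimensions, giving an output
// shape of [4, 2, 3, 6].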
1432 
1433 void ProcessGatherNdOperator(Model* model, GatherNdOperator* op) {
1434   const auto& input_array = model->GetArray(op->inputs[0]);
1435   const auto& indices_array = model->GetArray(op->inputs[1]);
1436   auto& output_array = model->GetArray(op->outputs[0]);
1437 
1438   // Bail if we already know the output shape.
1439   if (output_array.has_shape()) {
1440     return;
1441   }
1442 
1443   // Yield until input dims have been resolved.
1444   if (!input_array.has_shape() || !indices_array.has_shape()) {
1445     return;
1446   }
1447 
1448   const auto& input_shape = input_array.shape();
1449   const auto& indices_shape = indices_array.shape();
1450   QCHECK_GE(input_shape.dimensions_count(), 1);
1451   QCHECK_GE(indices_shape.dimensions_count(), 1);
1452   const int indices_nd =
1453       indices_shape.dims(indices_shape.dimensions_count() - 1);
1454   QCHECK_LE(indices_nd, input_shape.dimensions_count());
1455 
1456   auto output_dims = output_array.mutable_shape()->mutable_dims();
1457   for (int dim = 0; dim < indices_shape.dimensions_count() - 1; ++dim) {
1458     output_dims->push_back(indices_shape.dims(dim));
1459   }
1460   for (int dim = indices_nd; dim < input_shape.dimensions_count(); ++dim) {
1461     output_dims->push_back(input_shape.dims(dim));
1462   }
1463 }
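// Worked example (hypothetical shapes, for illustration only): with an input of
// shape [4, 5, 6] and indices of shape [2, 3, 2] (so indices_nd = 2), the
// output keeps the leading indices dims [2, 3] and the trailing input dim [6],
// giving an output shape of [2, 3, 6].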
1464 
1465 void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) {
1466   const auto& input_values = model->GetArray(op->inputs[0]);
1467   const auto& input_k = model->GetArray(op->inputs[1]);
1468   auto& output_values = model->GetArray(op->outputs[0]);
1469   auto& output_indexes = model->GetArray(op->outputs[1]);
1470 
1471   // Bail if we already know the output shape.
1472   if (output_indexes.has_shape()) {
1473     QCHECK(output_values.has_shape());
1474     return;
1475   }
1476 
1477   // Yield until input dims have been resolved.
1478   if (!input_values.has_shape() || !input_k.has_shape()) {
1479     return;
1480   }
1481 
1482   // If k is constant, we can compute the last output dimension; otherwise
1483   // it remains unknown.
1484   if (input_k.buffer) {
1485     const auto& input_values_shape = input_values.shape();
1486     auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims();
1487     auto output_values_dims = output_values.mutable_shape()->mutable_dims();
1488     for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) {
1489       output_indexes_dims->push_back(input_values_shape.dims(dim));
1490       output_values_dims->push_back(input_values_shape.dims(dim));
1491     }
1492     const int32_t k_value = input_k.GetBuffer<ArrayDataType::kInt32>().data[0];
1493     output_indexes_dims->push_back(k_value);
1494     output_values_dims->push_back(k_value);
1495   }
1496 }
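// Worked example (hypothetical shapes, for illustration only): with input
// values of shape [3, 10] and a constant k = 4, both outputs keep the leading
// input dim 3 and get k as the last dim, i.e. shape [3, 4].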
1497 
1498 void ProcessPadOperator(Model* model, PadOperator* op) {
1499   CHECK_EQ(op->inputs.size(), 2);
1500   CHECK_EQ(op->outputs.size(), 1);
1501 
1502   const auto& input_array = model->GetArray(op->inputs[0]);
1503 
1504   // Yield until input dims have been resolved.
1505   if (!input_array.has_shape()) return;
1506 
1507   if (op->left_padding.empty()) return;
1508   CHECK_EQ(op->left_padding.size(), op->right_padding.size());
1509 
1510   auto& output_array = model->GetArray(op->outputs[0]);
1511   if (output_array.has_shape()) return;
1512 
1513   Shape output_shape = input_array.shape();
1514   std::vector<int>& dims = *output_shape.mutable_dims();
1515   CHECK_EQ(op->left_padding.size(), dims.size());
1516 
1517   for (size_t i = 0; i < op->left_padding.size(); ++i) {
1518     dims[i] += op->left_padding[i] + op->right_padding[i];
1519   }
1520 
1521   output_array.copy_shape(output_shape);
1522 }
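// Worked example (hypothetical shapes, for illustration only): padding an input
// of shape [1, 4, 4, 1] with left_padding = {0, 1, 1, 0} and right_padding =
// {0, 1, 1, 0} grows each dim by the two paddings, giving [1, 6, 6, 1].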
1523 
1524 void ProcessPadV2Operator(Model* model, PadV2Operator* op) {
1525   CHECK_EQ(op->inputs.size(), 3);
1526   CHECK_EQ(op->outputs.size(), 1);
1527 
1528   const auto& input_array = model->GetArray(op->inputs[0]);
1529 
1530   // Yield until input dims have been resolved.
1531   if (!input_array.has_shape()) return;
1532 
1533   if (op->left_padding.empty()) return;
1534   CHECK_EQ(op->left_padding.size(), op->right_padding.size());
1535 
1536   auto& output_array = model->GetArray(op->outputs[0]);
1537   if (output_array.has_shape()) return;
1538 
1539   Shape output_shape = input_array.shape();
1540   std::vector<int>& dims = *output_shape.mutable_dims();
1541   CHECK_EQ(op->left_padding.size(), dims.size());
1542 
1543   for (size_t i = 0; i < op->left_padding.size(); ++i) {
1544     dims[i] += op->left_padding[i] + op->right_padding[i];
1545   }
1546 
1547   output_array.copy_shape(output_shape);
1548 }
1549 
1550 void ProcessRankOperator(Model* model, TensorFlowRankOperator* op) {
1551   CHECK_GE(op->inputs.size(), 1);
1552   CHECK_EQ(op->outputs.size(), 1);
1553   auto& output_array = model->GetArray(op->outputs[0]);
1554   if (output_array.has_shape()) {
1555     // Shape already propagated
1556     return;
1557   }
1558 
1559   if (output_array.data_type == ArrayDataType::kNone) {
1560     // Yield until the output type has been set by PropagateArrayDataTypes
1561     return;
1562   }
1563 
1564   const auto& input_array = model->GetArray(op->inputs[0]);
1565   if (!input_array.has_shape()) {
1566     // Yield until input dims have been resolved.
1567     return;
1568   }
1569 
1570   // Only set the output shape. Array contents are set by
1571   // ResolveConstantShapeOrRank.
1572   Shape* output_shape = output_array.mutable_shape();
1573   output_shape->ReplaceDims({});
1574 }
1575 
1576 void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
1577   CHECK_GE(op->inputs.size(), 1);
1578   CHECK_EQ(op->outputs.size(), 1);
1579   auto& output_array = model->GetArray(op->outputs[0]);
1580   if (output_array.has_shape()) {
1581     // Shape already propagated
1582     return;
1583   }
1584 
1585   if (output_array.data_type == ArrayDataType::kNone) {
1586     // Yield until the output type has been set by PropagateArrayDataTypes
1587     return;
1588   }
1589 
1590   const auto& input_array = model->GetArray(op->inputs[0]);
1591   if (!input_array.has_shape()) {
1592     // Yield until input dims have been resolved.
1593     return;
1594   }
1595 
1596   // Only set the output shape. Array contents are set by
1597   // ResolveConstantShapeOrRank.
1598   Shape* output_shape = output_array.mutable_shape();
1599   output_shape->ReplaceDims({input_array.shape().dimensions_count()});
1600 }
1601 
1602 void ProcessPackOperator(Model* model, PackOperator* op) {
1603   CHECK_GE(op->inputs.size(), 1);
1604   CHECK_EQ(op->outputs.size(), 1);
1605   auto& output_array = model->GetArray(op->outputs[0]);
1606   if (output_array.has_shape()) {
1607     // Shape already propagated
1608     return;
1609   }
1610 
1611   std::unique_ptr<Shape> packed_shape;
1612   for (const auto& input : op->inputs) {
1613     const auto& input_array = model->GetArray(input);
1614     if (!input_array.has_shape()) {
1615       // Yield until all input dims have been resolved.
1616       return;
1617     }
1618 
1619     Shape shape = input_array.shape();
1620     if (!packed_shape) {
1621       packed_shape = std::make_unique<Shape>(shape);
1622     } else {
1623       CHECK(*packed_shape == shape) << "All input arrays to Pack operators "
1624                                        "must have the same shape. Input \""
1625                                     << input << "\" is different.";
1626     }
1627   }
1628 
1629   int axis = op->axis;
1630   if (axis < 0) {
1631     // Handle negative axis
1632     axis += packed_shape->dims().size() + 1;
1633   }
1634   packed_shape->mutable_dims()->insert(
1635       packed_shape->mutable_dims()->begin() + axis, op->inputs.size());
1636   output_array.copy_shape(*packed_shape);
1637 }
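// Worked example (hypothetical shapes, for illustration only): packing 3 inputs
// of shape [4, 5] with axis = -1 resolves the axis to -1 + 2 + 1 = 2, so the
// input count 3 is inserted there and the output shape is [4, 5, 3].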
1638 
1639 void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
1640   CHECK_GE(op->inputs.size(), 1);
1641   CHECK_EQ(op->outputs.size(), 1);
1642   auto& output_array = model->GetArray(op->outputs[0]);
1643   if (output_array.has_shape()) {
1644     // Shape already propagated
1645     return;
1646   }
1647 
1648   if (op->start_indices.empty() || op->stop_indices.empty() ||
1649       op->strides.empty()) {
1650     // ResolveStridedSliceAttributes has not run yet.
1651     return;
1652   }
1653 
1654   const auto& input_array = model->GetArray(op->inputs[0]);
1655   if (!input_array.has_shape()) {
1656     // Yield until input dims have been resolved.
1657     return;
1658   }
1659 
1660   if (op->ellipsis_mask != 0) {
1661     // Something like LOG_FIRST_N(WARNING, 10) would be preferable to reduce
1662     // log noise. However, the TensorFlow logging library does not appear to
1663     // support this.
1664     LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
1665                  << "\". ellipsis_mask is not supported (mask="
1666                  << op->ellipsis_mask << ")";
1667     return;
1668   }
1669   if (op->new_axis_mask != 0) {
1670     LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
1671                  << "\". new_axis_mask is not supported (mask="
1672                  << op->new_axis_mask << ")";
1673     return;
1674   }
1675 
1676   int num_input_axes = input_array.shape().dimensions_count();
1677   CHECK_LE(op->start_indices.size(), num_input_axes)
1678       << "StridedSlice op with output \"" << op->outputs[0]
1679       << "\", requires no more than " << num_input_axes << " start indices";
1680   CHECK_LE(op->stop_indices.size(), num_input_axes)
1681       << "StridedSlice op with output \"" << op->outputs[0]
1682       << "\", requires no more than " << num_input_axes << " stop indices";
1683   CHECK_LE(op->strides.size(), num_input_axes)
1684       << "StridedSlice op with output \"" << op->outputs[0]
1685       << "\", requires no more than " << num_input_axes << " strides";
1686   for (size_t i = 0; i < op->strides.size(); i++) {
1687     CHECK_NE(op->strides[i], 0) << "Strides must be non-zero. Axis " << i
1688                                 << " has stride=" << op->strides[i] << ".";
1689   }
1690 
1691   // Create output shape
1692   std::vector<int>* dims = output_array.mutable_shape()->mutable_dims();
1693 
1694   // Compute output shape
1695   for (int axis = 0; axis < num_input_axes; ++axis) {
1696     const auto strided_slice_params =
1697         tflite::strided_slice::BuildStridedSliceParams(
1698             op->begin_mask, op->end_mask, op->shrink_axis_mask,
1699             op->start_indices, op->stop_indices, op->strides);
1700     int start_index = tflite::strided_slice::StartForAxis(
1701         strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
1702     int stop_index = tflite::strided_slice::StopForAxis(
1703         strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
1704         start_index);
1705 
1706     int dim_size = std::ceil(static_cast<float>(stop_index - start_index) /
1707                              op->strides[axis]);
1708 
1709     CHECK_GT(dim_size, 0)
1710         << "Output size for an axis must be greater than 0. Axis " << axis
1711         << " computes to size " << dim_size
1712         << " for StridedSlice op with output \"" << op->outputs[0] << "\".";
1713     if (op->shrink_axis_mask & (1 << axis)) {
1714       CHECK_EQ(dim_size, 1)
1715           << "Output size for an axis must compute to 1 when shrinking an "
1716              "axis. Axis "
1717           << axis << " computes to size " << dim_size
1718           << " for StridedSlice op with output \"" << op->outputs[0] << "\".";
1719     } else {
1720       dims->push_back(dim_size);
1721     }
1722   }
1723 }
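// Worked example (hypothetical values, for illustration only): slicing an input
// of shape [6] with start_indices = {1}, stop_indices = {6}, strides = {2} and
// no masks gives dim_size = ceil((6 - 1) / 2) = 3, so the output shape is [3].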
1724 
1725 void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
1726   CHECK_EQ(op->inputs.size(), 1);
1727   CHECK_EQ(op->outputs.size(), 1);
1728 
1729   const auto& input_array = model->GetArray(op->inputs[0]);
1730 
1731   // Yield until input dims have been resolved.
1732   if (!input_array.has_shape()) return;
1733 
1734   auto& output_array = model->GetArray(op->outputs[0]);
1735   if (output_array.has_shape()) return;
1736 
1737   const std::vector<int>& input_dims = input_array.shape().dims();
1738   std::vector<int> output_dims;
1739 
1740   std::vector<int> squeeze_dims;
1741   const int input_num_dims = input_dims.size();
1742   squeeze_dims.reserve(op->squeeze_dims.size());
1743   for (int i : op->squeeze_dims) {
1744     squeeze_dims.push_back(i < 0 ? i + input_num_dims : i);
1745   }
1746   for (int i = 0; i < input_num_dims; ++i) {
1747     if (input_dims[i] != 1 ||
1748         (!squeeze_dims.empty() &&
1749          std::find(squeeze_dims.begin(), squeeze_dims.end(), i) ==
1750              squeeze_dims.end())) {
1751       output_dims.push_back(input_dims[i]);
1752     }
1753   }
1754   *output_array.mutable_shape()->mutable_dims() = output_dims;
1755 }
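// Worked example (hypothetical shapes, for illustration only): squeezing an
// input of shape [1, 4, 1, 5] with empty squeeze_dims drops every size-1 dim,
// giving [4, 5]; with squeeze_dims = {0} only dim 0 is dropped, giving
// [4, 1, 5].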
1756 
1757 void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
1758   CHECK(op->inputs.size() == 4 || op->inputs.size() == 5);
1759   const auto& input_array = model->GetArray(op->inputs[0]);
1760   if (!input_array.has_shape()) return;
1761 
1762   auto& weights_feature_array = model->GetArray(op->inputs[1]);
1763   if (!weights_feature_array.has_shape()) return;
1764 
1765   const auto& weights_time_array = model->GetArray(op->inputs[2]);
1766   if (!weights_time_array.has_shape()) return;
1767 
1768   const bool has_bias = (op->inputs.size() == 5);
1769   if (has_bias) {
1770     const auto& bias_array = model->GetArray(op->inputs[3]);
1771     if (!bias_array.has_shape()) return;
1772   }
1773 
1774   const int batch_size = input_array.shape().dims()[0];
1775   const int num_units = weights_feature_array.shape().dims()[0];
1776   const int memory_size = weights_time_array.shape().dims()[1];
1777 
1778   auto& state_array = model->GetArray(op->outputs[0]);
1779   state_array.mutable_shape()->ReplaceDims(
1780       {batch_size, memory_size * num_units});
1781 
1782   auto& output_array = model->GetArray(op->outputs[1]);
1783   output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
1784 }
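// Worked example (hypothetical shapes, for illustration only): with an input of
// shape [2, 10], weights_feature of shape [16, 10] and weights_time of shape
// [16, 8], the state output gets shape [2, 8 * 16] = [2, 128] and the main
// output gets shape [2, 16].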
1785 
1786 void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
1787   auto& output_array = model->GetArray(op->outputs[0]);
1788   if (output_array.has_shape()) {
1789     // We have already run
1790     return;
1791   }
1792 
1793   const auto& input_array = model->GetArray(op->inputs[0]);
1794   if (!input_array.has_shape()) {
1795     // Yield until input dims have been resolved.
1796     return;
1797   }
1798   const auto& input_shape = input_array.shape();
1799 
1800   auto& perm_array = model->GetArray(op->inputs[1]);
1801   if (!perm_array.has_shape()) {
1802     // Yield until the permutation shape has been resolved.
1803     return;
1804   }
1805   if (!perm_array.buffer) {
1806     // Yield until the permutation is constant.
1807     return;
1808   }
1809   CHECK(perm_array.data_type == ArrayDataType::kInt32)
1810       << "Transpose permutation input must be int32";
1811 
1812   std::vector<int32> const& perm =
1813       perm_array.GetBuffer<ArrayDataType::kInt32>().data;
1814   CHECK_EQ(perm.size(), input_shape.dimensions_count())
1815       << "Transpose permutation input " << op->inputs[1]
1816       << " must be same length as input dimensions";
1817   std::vector<int>* output_dims = output_array.mutable_shape()->mutable_dims();
1818   for (size_t i = 0; i < perm.size(); i++) {
1819     int axis = perm[i];
1820     CHECK_GE(axis, 0);
1821     CHECK_LT(axis, input_shape.dimensions_count());
1822     output_dims->push_back(input_shape.dims(axis));
1823   }
1824 }
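// Worked example (hypothetical shapes, for illustration only): transposing an
// input of shape [2, 3, 4] with a constant perm = {2, 0, 1} picks the input
// dims in that order, giving an output shape of [4, 2, 3].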
1825 
1826 template <typename Op>
1827 void ProcessArgMinMaxOperator(Model* model, Op* op) {
1828   CHECK_EQ(op->inputs.size(), 2);
1829   const auto& input_array = model->GetArray(op->inputs[0]);
1830   // Yield until input dims have been resolved.
1831   if (!input_array.has_shape()) {
1832     return;
1833   }
1834 
1835   const Array& axis_array = model->GetArray(op->inputs[1]);
1836   // Yield until input axis array shape has been resolved.
1837   if (!axis_array.has_shape()) {
1838     return;
1839   }
1840 
1841   const std::vector<int>& input_dims = input_array.shape().dims();
1842 
1843   CHECK(axis_array.data_type == ArrayDataType::kInt32 ||
1844         axis_array.data_type == ArrayDataType::kInt64)
1845       << "axis_array must be int32, int64";
1846 
1847   CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
1848       << "Axis array must be scalar.";
1849 
1850   int64_t axis;
1851   if (axis_array.data_type == ArrayDataType::kInt32) {
1852     axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
1853   } else {
1854     axis = axis_array.GetBuffer<ArrayDataType::kInt64>().data[0];
1855   }
1856 
1857   std::vector<int> output_dims;
1858 
1859   output_dims.reserve(input_dims.size() - 1);
1860   for (size_t i = 0; i < input_dims.size(); ++i) {
1861     if (static_cast<int>(i) != axis) {
1862       output_dims.push_back(input_dims[i]);
1863     }
1864   }
1865 
1866   const std::string& output_name = op->outputs[0];
1867   auto& output_array = model->GetArray(output_name);
1868   if (output_array.has_shape()) {
1869     return;
1870   }
1871   *output_array.mutable_shape()->mutable_dims() = output_dims;
1872 }
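// Worked example (hypothetical shapes, for illustration only): an ArgMax over
// an input of shape [2, 3, 4] with a constant axis = 1 drops that dimension,
// giving an output shape of [2, 4].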
1873 
1874 void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
1875   CHECK_EQ(op->inputs.size(), 4);
1876 
1877   const Array& output_shape_array = model->GetArray(op->inputs[1]);
1878   if (!output_shape_array.has_shape()) return;
1879   CHECK_EQ(output_shape_array.shape().dimensions_count(), 1);
1880 
1881   // Output should not go over four dimensions.
1882   CHECK_LE(output_shape_array.shape().dims(0), 4);
1883 
1884   const std::string& output_name = op->outputs[0];
1885   Array& output_array = model->GetArray(output_name);
1886   if (output_array.has_shape()) return;
1887 
1888   CHECK(output_shape_array.data_type == ArrayDataType::kInt32 ||
1889         output_shape_array.data_type == ArrayDataType::kInt64);
1890   if (output_shape_array.data_type == ArrayDataType::kInt32) {
1891     *output_array.mutable_shape()->mutable_dims() =
1892         output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1893   } else {
1894     const std::vector<int64_t>& output_shape_data =
1895         output_shape_array.GetBuffer<ArrayDataType::kInt64>().data;
1896     // Explicitly cast elements to int in order to avoid MSVC warnings about
1897     // narrowing conversions.
1898     std::transform(
1899         output_shape_data.begin(), output_shape_data.end(),
1900         std::back_inserter(*output_array.mutable_shape()->mutable_dims()),
1901         [](const int64_t dim) { return static_cast<int>(dim); });
1902   }
1903 }
1904 
1905 void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
1906   CHECK_EQ(op->inputs.size(), 2);
1907   CHECK_EQ(op->outputs.size(), 1);
1908 
1909   auto& output_array = model->GetArray(op->outputs[0]);
1910   if (output_array.has_shape()) {
1911     // We have already run.
1912     return;
1913   }
1914 
1915   const auto& input_array = model->GetArray(op->inputs[0]);
1916   if (!input_array.has_shape()) {
1917     // Yield until input dims have been resolved.
1918     return;
1919   }
1920   const auto& input_shape = input_array.shape();
1921 
1922   auto& multiples_array = model->GetArray(op->inputs[1]);
1923   if (!multiples_array.has_shape()) {
1924     // Yield until the multiples shape has been resolved.
1925     return;
1926   }
1927   if (!multiples_array.buffer) {
1928     // Yield until the multiples are constant.
1929     return;
1930   }
1931   CHECK(multiples_array.data_type == ArrayDataType::kInt32)
1932       << "Tile multiples input must be int32";
1933 
1934   std::vector<int32> const& multiples =
1935       multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
1936   CHECK_EQ(multiples.size(), input_shape.dimensions_count())
1937       << "Tile multiples input " << op->inputs[1]
1938       << " must be same length as input dimensions";
1939 
1940   auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
1941   mutable_dims->resize(multiples.size());
1942   for (size_t i = 0; i < mutable_dims->size(); ++i) {
1943     (*mutable_dims)[i] = input_shape.dims(i) * multiples[i];
1944   }
1945 }
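// Worked example (hypothetical shapes, for illustration only): tiling an input
// of shape [2, 3] with constant multiples = {3, 2} scales each dim accordingly,
// giving an output shape of [6, 6].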
1946 
1947 void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
1948   CHECK_EQ(op->inputs.size(), 4);
1949   CHECK_EQ(op->outputs.size(), 1);
1950   auto& output_array = model->GetArray(op->outputs[0]);
1951   if (output_array.has_shape()) {
1952     // Shape already propagated
1953     return;
1954   }
1955 
1956   // Yield until indices dims have been resolved.
1957   const auto& indices_array =
1958       model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]);
1959   if (!indices_array.has_shape()) {
1960     return;
1961   }
1962 
1963   // Yield until depth is constant and dims have been resolved.
1964   if (!IsConstantParameterArray(*model,
1965                                 op->inputs[OneHotOperator::DEPTH_INPUT])) {
1966     return;
1967   }
1968   const auto& depth_array =
1969       model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]);
1970   if (!depth_array.has_shape()) {
1971     return;
1972   }
1973 
1974   CHECK(depth_array.data_type == ArrayDataType::kInt32)
1975       << "Depth array must be int32.";
1976   CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1)
1977       << "Depth array must be scalar.";
1978 
1979   const int depth = depth_array.GetBuffer<ArrayDataType::kInt32>().data[0];
1980   CHECK_GE(depth, 0) << "Depth must be non-negative.";
1981 
1982   const int indices_dims = indices_array.shape().dimensions_count();
1983   const int output_dims = indices_dims + 1;
1984   const int axis = op->axis == -1 ? indices_dims : op->axis;
1985   CHECK_GE(axis, 0) << "Resolved axis must be non-negative.";
1986 
1987   auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
1988   mutable_dims->resize(output_dims);
1989   for (int i = 0; i < output_dims; ++i) {
1990     int dim = 0;
1991     if (i < axis) {
1992       dim = indices_array.shape().dims(i);
1993     } else if (i == axis) {
1994       dim = depth;
1995     } else {
1996       dim = indices_array.shape().dims(i - 1);
1997     }
1998     (*mutable_dims)[i] = dim;
1999   }
2000 }
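// Worked example (hypothetical shapes, for illustration only): with indices of
// shape [4, 3], a constant depth = 5 and axis = -1, the axis resolves to 2 and
// the propagated output shape is [4, 3, 5].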
2001 
2002 void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
2003   CHECK_EQ(op->inputs.size(), 1);
2004   const auto& input_array = model->GetArray(op->inputs[0]);
2005   // Yield until input dims have been resolved.
2006   if (!input_array.has_shape()) {
2007     return;
2008   }
2009 
2010   const std::vector<int>& input_dims = input_array.shape().dims();
2011   std::vector<int> output_dims;
2012 
2013   output_dims.reserve(input_dims.size() - 1);
2014   for (size_t i = 0; i < input_dims.size(); ++i) {
2015     if (static_cast<int>(i) != op->axis) {
2016       output_dims.push_back(input_dims[i]);
2017     }
2018   }
2019   for (const std::string& output_name : op->outputs) {
2020     auto& output_array = model->GetArray(output_name);
2021     if (output_array.has_shape()) {
2022       return;
2023     }
2024     *output_array.mutable_shape()->mutable_dims() = output_dims;
2025   }
2026 }
2027 
2028 void ProcessMirrorPadOperator(Model* model, MirrorPadOperator* op) {
2029   CHECK_EQ(op->inputs.size(), 2);
2030   const auto& input_array = model->GetArray(op->inputs[0]);
2031   const auto& padding_matrix = model->GetArray(op->inputs[1]);
2032 
2033   // Yield until input dims have been resolved.
2034   if (!input_array.has_shape()) {
2035     return;
2036   }
2037 
2038   auto& output_array = model->GetArray(op->outputs[0]);
2039   // Return if the output shape has already been computed or if the padding
2040   // matrix is not constant.
2041   if (output_array.has_shape() ||
2042       !IsConstantParameterArray(*model, op->inputs[1])) {
2043     return;
2044   }
2045   Shape output_shape = input_array.shape();
2046   std::vector<int>& dims = *output_shape.mutable_dims();
2047 
2048   std::vector<int64_t> padding;
2049   if (padding_matrix.data_type == ArrayDataType::kInt32) {
2050     const auto& data = padding_matrix.GetBuffer<ArrayDataType::kInt32>().data;
2051     for (auto elem : data) {
2052       padding.push_back(static_cast<int64_t>(elem));
2053     }
2054   } else if (padding_matrix.data_type == ArrayDataType::kInt64) {
2055     const auto& data = padding_matrix.GetBuffer<ArrayDataType::kInt64>().data;
2056     for (auto elem : data) {
2057       padding.push_back(elem);
2058     }
2059   } else {
2060     CHECK(padding_matrix.data_type == ArrayDataType::kInt64 ||
2061           padding_matrix.data_type == ArrayDataType::kInt32);
2062   }
2063   CHECK_EQ(padding_matrix.shape().dimensions_count(), 2);
2064   CHECK_EQ(input_array.shape().dimensions_count(),
2065            padding_matrix.shape().dims(0));
2066   for (int i = 0; i < input_array.shape().dimensions_count(); ++i) {
2067     dims[i] += padding[i * 2] + padding[i * 2 + 1];
2068   }
2069 
2070   output_array.copy_shape(output_shape);
2071 }
2072 
2073 void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
2074   const auto& input_array = model->GetArray(op->inputs[0]);
2075   // We have 2 outputs. The index tensor has the same shape as the input
2076   // array; the shape of the unique-values tensor is unknown until runtime.
2077   CHECK_EQ(op->outputs.size(), 2);
2078   auto& idx_output_array = model->GetArray(op->outputs[1]);
2079 
2080   // Yield until input dims have been resolved, or output already computed
2081   if (!input_array.has_shape() || idx_output_array.has_shape()) {
2082     return;
2083   }
2084   idx_output_array.copy_shape(input_array.shape());
2085 }
2086 
2087 void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
2088   CHECK_EQ(op->inputs.size(), 1);
2089   CHECK_EQ(op->outputs.size(), 1);
2090   auto& input_array = model->GetArray(op->inputs[0]);
2091   auto& output_array = model->GetArray(op->outputs[0]);
2092   // The input array must have a shape in order to proceed. Also,
2093   // bail out if the output shape has already been calculated.
2094   if (!input_array.has_shape() || output_array.has_shape()) {
2095     // We have already run
2096     return;
2097   }
2098   // The output shape is the input shape with its last dimension duplicated.
2099   // Work on a copy so that the input array's shape is not modified.
2100   Shape output_shape = input_array.shape();
2101   std::vector<int>* dims = output_shape.mutable_dims();
2102   // Scalars are not allowed.
2103   CHECK_GT(dims->size(), 0);
2104   int last_dim = dims->back();
2105   dims->push_back(last_dim);
2106   output_array.copy_shape(output_shape);
2107 }
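// Worked example (hypothetical shapes, for illustration only): a MatrixDiag
// input of shape [3, 4] gets its last dimension duplicated, giving an output
// shape of [3, 4, 4].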
2108 
2109 void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) {
2110   CHECK_EQ(op->inputs.size(), 2);
2111   CHECK_EQ(op->outputs.size(), 1);
2112   auto& input_array = model->GetArray(op->inputs[0]);
2113   auto& output_array = model->GetArray(op->outputs[0]);
2114   // The shape of the input array must be known because that will
2115   // be the shape of the output array.
2116   if (!input_array.has_shape() || output_array.has_shape()) {
2117     // Yield until the input shape is known, or bail if we have already run.
2118     return;
2119   }
2120 
2121   output_array.copy_shape(input_array.shape());
2122 }
2123 
2124 void ProcessScatterNdOperator(Model* model, ScatterNdOperator* op) {
2125   CHECK_EQ(op->inputs.size(), 3);
2126   CHECK_EQ(op->outputs.size(), 1);
2127   auto& shape_array = model->GetArray(op->inputs[2]);
2128   auto& output_array = model->GetArray(op->outputs[0]);
2129 
2130   if (!shape_array.has_shape()) {
2131     // Yield until the dims shape has been resolved.
2132     return;
2133   }
2134   if (!shape_array.buffer) {
2135     // Yield until the dims are constant.
2136     return;
2137   }
2138   CHECK(shape_array.data_type == ArrayDataType::kInt32) << "dims must be int32";
2139 
2140   std::vector<int32> const& dims =
2141       shape_array.GetBuffer<ArrayDataType::kInt32>().data;
2142   *(output_array.mutable_shape()->mutable_dims()) = dims;
2143 }
2144 
2145 }  // namespace
2146 
2147 ::tensorflow::Status PropagateFixedSizes::Run(Model* model,
2148                                               std::size_t op_index,
2149                                               bool* modified) {
2150   *modified = false;
2151   auto it = model->operators.begin() + op_index;
2152   auto* op = it->get();
2153   std::unordered_map<std::string, std::vector<int>> old_output_dims;
2154   for (const auto& output : op->outputs) {
2155     if (model->GetArray(output).has_shape()) {
2156       old_output_dims[output] = model->GetArray(output).shape().dims();
2157     }
2158   }
2159 
2160   switch (op->type) {
2161     case OperatorType::kAbs:
2162     case OperatorType::kBatchNormalization:
2163     case OperatorType::kL2Normalization:
2164     case OperatorType::kDequantize:
2165     case OperatorType::kElu:
2166     case OperatorType::kHardSwish:
2167     case OperatorType::kRelu:
2168     case OperatorType::kRelu1:
2169     case OperatorType::kRelu6:
2170     case OperatorType::kPRelu:
2171     case OperatorType::kLeakyRelu:
2172     case OperatorType::kSoftmax:
2173     case OperatorType::kLogSoftmax:
2174     case OperatorType::kLog:
2175     case OperatorType::kLogistic:
2176     case OperatorType::kTanh:
2177     case OperatorType::kLocalResponseNormalization:
2178     case OperatorType::kIdentity:
2179     case OperatorType::kFakeQuant:
2180     case OperatorType::kNeg:
2181     case OperatorType::kRsqrt:
2182     case OperatorType::kSqrt:
2183     case OperatorType::kSquare:
2184     case OperatorType::kAll:
2185     case OperatorType::kAssert:
2186     case OperatorType::kCast:
2187     case OperatorType::kFloor:
2188     case OperatorType::kCeil:
2189     case OperatorType::kRound:
2190     case OperatorType::kExp:
2191     case OperatorType::kSin:
2192     case OperatorType::kCos:
2193     case OperatorType::kLogicalAnd:
2194     case OperatorType::kLogicalNot:
2195     case OperatorType::kLogicalOr:
2196     case OperatorType::kZerosLike:
2197     case OperatorType::kReverseV2:
2198     case OperatorType::kReverseSequence:
2199       ProcessSimpleOperator(model, op, 0);
2200       break;
2201     case OperatorType::kGather:
2202       ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
2203       break;
2204     case OperatorType::kGatherNd:
2205       ProcessGatherNdOperator(model, static_cast<GatherNdOperator*>(op));
2206       break;
2207     case OperatorType::kTopK_V2:
2208       ProcessTopkV2Operator(model, static_cast<TopKV2Operator*>(op));
2209       break;
2210     case OperatorType::kAdd:
2211     case OperatorType::kSub:
2212     case OperatorType::kMul:
2213     case OperatorType::kDiv:
2214     case OperatorType::kFloorDiv:
2215     case OperatorType::kFloorMod:
2216     case OperatorType::kLess:
2217     case OperatorType::kLessEqual:
2218     case OperatorType::kGreater:
2219     case OperatorType::kMaximum:  //  Element-wise Maximum
2220     case OperatorType::kMinimum:  //  Element-wise Minimum
2221     case OperatorType::kGreaterEqual:
2222     case OperatorType::kEqual:
2223     case OperatorType::kNotEqual:
2224     case OperatorType::kPow:
2225     case OperatorType::kSquaredDifference:
2226       ProcessSimpleBinaryOperator(model, op);
2227       break;
2228     case OperatorType::kAddN:
2229       ProcessAddNOperator(model, op);
2230       break;
2231     case OperatorType::kConv:
2232       ProcessConvOperator(model, static_cast<ConvOperator*>(op));
2233       break;
2234     case OperatorType::kTransposeConv:
2235       ProcessTransposeConvOperator(model,
2236                                    static_cast<TransposeConvOperator*>(op));
2237       break;
2238     case OperatorType::kDepthwiseConv:
2239       ProcessDepthwiseConvOperator(model,
2240                                    static_cast<DepthwiseConvOperator*>(op));
2241       break;
2242     case OperatorType::kDepthToSpace:
2243       ProcessDepthToSpaceOperator(model,
2244                                   static_cast<DepthToSpaceOperator*>(op));
2245       break;
2246     case OperatorType::kSpaceToDepth:
2247       ProcessSpaceToDepthOperator(model,
2248                                   static_cast<SpaceToDepthOperator*>(op));
2249       break;
2250     case OperatorType::kFill:
2251       CHECK_EQ(op->inputs.size(), 2);
2252       ProcessOpWithShapeInput(model, op);
2253       break;
2254     case OperatorType::kFullyConnected:
2255       ProcessFullyConnectedOperator(model,
2256                                     static_cast<FullyConnectedOperator*>(op));
2257       break;
2258     case OperatorType::kReshape:
2259       ProcessTensorFlowReshapeOperator(
2260           model, static_cast<TensorFlowReshapeOperator*>(op));
2261       break;
2262     case OperatorType::kAveragePool:
2263       ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op));
2264       break;
2265     case OperatorType::kMaxPool:
2266       ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op));
2267       break;
2268     case OperatorType::kL2Pool:
2269       ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
2270       break;
2271     case OperatorType::kReduceMin:  //  Reduction Min
2272     case OperatorType::kReduceMax:  //  Reduction Max
2273     case OperatorType::kSum:
2274     case OperatorType::kReduceProd:
2275     case OperatorType::kMean:
2276     case OperatorType::kAny:
2277       ProcessTensorFlowReductionOperator(model, op);
2278       break;
2279     case OperatorType::kSelect:
2280       ProcessSelectOperator(model, static_cast<SelectOperator*>(op));
2281       break;
2282     case OperatorType::kSlice:
2283       ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
2284       break;
2285 
2286     case OperatorType::kSwitch:
2287       // We can't know the sizes of the outputs until we have resolved the
2288       // predicate, and once we have resolved the predicate, the whole
2289       // Switch node will get resolved away.
2290       // See ResolveTensorFlowSwitch.
2291       break;
2292     case OperatorType::kMerge:
2293       // No need to bother resolving TensorFlow Merge ops: other graph
2294       // transformations will remove them anyway.
2295       // See ResolveTensorFlowMerge.
2296       break;
2297     case OperatorType::kSplit:
2298       ProcessTensorFlowSplitOperator(model,
2299                                      static_cast<TensorFlowSplitOperator*>(op));
2300       break;
2301     case OperatorType::kSplitV:
2302       ProcessTensorFlowSplitVOperator(
2303           model, static_cast<TensorFlowSplitVOperator*>(op));
2304       break;
2305     case OperatorType::kSqueeze:
2306       ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
2307       break;
2308     case OperatorType::kConcat:
2309     case OperatorType::kConcatV2:
2310       // Unimplemented, hopefully another graph transformation will
2311       // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
2312       // will resolve this node to a DepthConcatenation, or else we have
2313       // a more general non-depth concatenation that will hopefully be dropped,
2314       // or else at the moment we will abort.
2315       break;
2316     case OperatorType::kExpandDims:
2317       // Yield until ExpandDims is converted to Reshape
2318       break;
2319     case OperatorType::kRange:
2320       ProcessRangeOperator(model, static_cast<RangeOperator*>(op));
2321       break;
2322     case OperatorType::kRank:
2323       ProcessRankOperator(model, static_cast<TensorFlowRankOperator*>(op));
2324       break;
2325     case OperatorType::kShape:
2326       ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
2327       break;
2328     case OperatorType::kPack:
2329       ProcessPackOperator(model, static_cast<PackOperator*>(op));
2330       break;
2331     case OperatorType::kReorderAxes:
2332       ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
2333       break;
2334     case OperatorType::kConcatenation:
2335       ProcessConcatenationOperator(model,
2336                                    static_cast<ConcatenationOperator*>(op));
2337       break;
2338     case OperatorType::kResizeBilinear:
2339       ProcessResizeBilinearOperator(model,
2340                                     static_cast<ResizeBilinearOperator*>(op));
2341       break;
2342     case OperatorType::kResizeNearestNeighbor:
2343       ProcessResizeNearestNeighborOperator(
2344           model, static_cast<ResizeNearestNeighborOperator*>(op));
2345       break;
2346     case OperatorType::kUnidirectionalSequenceLstm:
2347       ProcessUnidirectionalSequenceLstmOperator(
2348           model, static_cast<UnidirectionalSequenceLstmOperator*>(op));
2349       break;
2350     case OperatorType::kUnidirectionalSequenceRnn:
2351       ProcessUnidirectionalSequenceRnnOperator(
2352           model, static_cast<UnidirectionalSequenceRnnOperator*>(op));
2353       break;
2354     case OperatorType::kBidirectionalSequenceLstm:
2355       ProcessBidirectionalSequenceLstmOperator(
2356           model, static_cast<BidirectionalSequenceLstmOperator*>(op));
2357       break;
2358     case OperatorType::kBidirectionalSequenceRnn:
2359       ProcessBidirectionalSequenceRnnOperator(
2360           model, static_cast<BidirectionalSequenceRnnOperator*>(op));
2361       break;
2362     case OperatorType::kLstmCell:
2363       ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
2364       break;
2365     case OperatorType::kBatchMatMul:
2366     case OperatorType::kMatMul:
2367       // MatMul operators are converted to FullyConnected, after which their
2368       // shapes are propagated.
2369       break;
2370     case OperatorType::kSpaceToBatchND:
2371       ProcessSpaceToBatchNDOperator(model,
2372                                     static_cast<SpaceToBatchNDOperator*>(op));
2373       break;
2374     case OperatorType::kBatchToSpaceND:
2375       ProcessBatchToSpaceNDOperator(model,
2376                                     static_cast<BatchToSpaceNDOperator*>(op));
2377       break;
2378     case OperatorType::kPad:
2379       ProcessPadOperator(model, static_cast<PadOperator*>(op));
2380       break;
2381     case OperatorType::kPadV2:
2382       ProcessPadV2Operator(model, static_cast<PadV2Operator*>(op));
2383       break;
2384     case OperatorType::kStridedSlice:
2385       ProcessStridedSliceOperator(model,
2386                                   static_cast<StridedSliceOperator*>(op));
2387       break;
2388     case OperatorType::kArgMax:
2389       ProcessArgMinMaxOperator<ArgMaxOperator>(
2390           model, static_cast<ArgMaxOperator*>(op));
2391       break;
2392     case OperatorType::kArgMin:
2393       ProcessArgMinMaxOperator<ArgMinOperator>(
2394           model, static_cast<ArgMinOperator*>(op));
2395       break;
2396     case OperatorType::kUnsupported: {
2397       const auto* unsupported_op =
2398           static_cast<TensorFlowUnsupportedOperator*>(op);
2399       // The output_shapes attribute may be missing or incomplete; skip it.
2400       if (unsupported_op->output_shapes.size() < op->outputs.size()) {
2401         return ::tensorflow::OkStatus();
2402       }
2403       for (size_t i = 0; i < op->outputs.size(); ++i) {
2404         const std::string& output = op->outputs[i];
2405         model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i));
2406       }
2407       break;
2408     }
2409     case OperatorType::kSvdf:
2410       ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
2411       break;
2412     case OperatorType::kTranspose:
2413       ProcessTransposeOperator(model, static_cast<TransposeOperator*>(op));
2414       break;
2415     case OperatorType::kDynamicPartition:
2416     case OperatorType::kDynamicStitch:
2417       // DynamicPartition/DynamicStitch are currently only supported for
2418       // transforms that remove them, so we avoid propagating shapes through
2419       // them and let things settle once they've been removed.
2420       break;
2421     case OperatorType::kRandomUniform:
2422       CHECK_EQ(op->inputs.size(), 1);
2423       ProcessOpWithShapeInput(model, op);
2424       break;
2425     case OperatorType::kSparseToDense:
2426       ProcessSparseToDenseOperator(model,
2427                                    static_cast<SparseToDenseOperator*>(op));
2428       break;
2429     case OperatorType::kTile:
2430       ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
2431       break;
2433     case OperatorType::kOneHot:
2434       ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
2435       break;
2436     case OperatorType::kUnpack:
2437       ProcessUnpackOperator(model, static_cast<UnpackOperator*>(op));
2438       break;
2439     case OperatorType::kMirrorPad:
2440       ProcessMirrorPadOperator(model, static_cast<MirrorPadOperator*>(op));
2441       break;
2442     case OperatorType::kUnique:
2443       ProcessUniqueOperator(model, static_cast<UniqueOperator*>(op));
2444       break;
2445     case OperatorType::kWhere:
2446       // The size of the output can only be known after evaluating the cond
2447       // tensor. Ignore shape propagation here and defer that to the
2448       // interpreter.
2449       break;
2450     case OperatorType::kMatrixDiag:
2451       ProcessMatrixDiagOperator(model, static_cast<MatrixDiagOperator*>(op));
2452       break;
2453     case OperatorType::kMatrixSetDiag:
2454       ProcessMatrixSetDiagOperator(model,
2455                                    static_cast<MatrixSetDiagOperator*>(op));
2456       break;
2457     case OperatorType::kCTCBeamSearchDecoder:
2458       // The sizes of the outputs are only known in runtime based on the input.
2459       // Ignore shape propagation here and defer that to the interpreter.
2460       break;
2461     case OperatorType::kMatrixSetDiagV2:
2462       // MatrixSetDiagV2 operators are converted to MatrixSetDiag,
2463       // after which their shapes are propagated.
2464       break;
2465     case OperatorType::kMatrixDiagV2:
2466       // MatrixDiagV2 operators are converted to MatrixDiag, after which their
2467       // shapes are propagated.
2468       break;
2469     case OperatorType::kMatrixDiagV3:
2470       // MatrixDiagV3 operators are converted to MatrixDiag, after which their
2471       // shapes are propagated.
2472       break;
2473     case OperatorType::kMatrixSetDiagV3:
2474       // MatrixSetDiagV3 operators are converted to MatrixSetDiag, after which
2475       // their shapes are propagated.
2476       break;
2477     case OperatorType::kSegmentSum:
2478       break;
2479     case OperatorType::kScatterNd:
2480       ProcessScatterNdOperator(model, static_cast<ScatterNdOperator*>(op));
2481       break;
2482     default:
2483       // Unimplemented, another graph transformation should drop it.
2484       LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
2485   }
2486 
2487   // Set *modified if any output dim changed, and leave it false otherwise.
2488   // Assumption: no transformation clears an output shape; they only add shapes.
2489   for (const auto& output : op->outputs) {
2490     if (model->GetArray(output).has_shape() &&
2491         (old_output_dims[output] != model->GetArray(output).shape().dims())) {
2492       AddMessageF("Set shape of %s to [%s]", output,
2493                   absl::StrJoin(model->GetArray(output).shape().dims(), ","));
2494       *modified = true;
2495       return ::tensorflow::OkStatus();
2496     }
2497   }
2498   return ::tensorflow::OkStatus();
2499 }
2500 
2501 }  // namespace toco
2502