xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/array_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <ostream>
18 
19 #include "tensorflow/core/framework/common_shape_fns.h"
20 #include "tensorflow/core/framework/kernel_shape_util.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/shape_inference.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/platform/types.h"
28 #include "tensorflow/core/util/mirror_pad_mode.h"
29 #include "tensorflow/core/util/padding.h"
30 #include "tensorflow/core/util/strided_slice_op.h"
31 #include "tensorflow/core/util/tensor_format.h"
32 
33 namespace tensorflow {
34 
35 using shape_inference::DimensionHandle;
36 using shape_inference::InferenceContext;
37 using shape_inference::ShapeHandle;
38 using shape_inference::UnchangedShape;
39 
40 namespace {
41 
GetAxisForPackAndUnpack(InferenceContext * c,int32_t rank_after_pack,int32 * axis)42 Status GetAxisForPackAndUnpack(InferenceContext* c, int32_t rank_after_pack,
43                                int32* axis) {
44   TF_RETURN_IF_ERROR(c->GetAttr("axis", axis));
45   if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) {
46     return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [",
47                                    -1 * rank_after_pack, ",", rank_after_pack,
48                                    ")");
49   }
50   if (*axis < 0) *axis = (rank_after_pack + *axis);
51   return OkStatus();
52 }
53 
54 template <typename T>
AsInt64(const Tensor * tensor,int64_t num_elements)55 std::vector<int64_t> AsInt64(const Tensor* tensor, int64_t num_elements) {
56   std::vector<int64_t> ret(num_elements);
57   auto data = tensor->vec<T>();
58   for (int64_t i = 0; i < num_elements; ++i) {
59     ret[i] = data(i);
60   }
61   return ret;
62 }
63 
64 template <typename T>
PadKnown(InferenceContext * c,ShapeHandle input,const Tensor * paddings_t,int64_t num_dims)65 Status PadKnown(InferenceContext* c, ShapeHandle input,
66                 const Tensor* paddings_t, int64_t num_dims) {
67   // paddings_t is known.
68   std::vector<DimensionHandle> dims(num_dims);
69   auto paddings_data = paddings_t->matrix<T>();
70   for (int64_t i = 0; i < num_dims; ++i) {
71     const T pad0 = paddings_data(i, 0);
72     const T pad1 = paddings_data(i, 1);
73     if (pad0 < 0 || pad1 < 0) {
74       return errors::InvalidArgument("Paddings must be non-negative");
75     }
76     TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i]));
77   }
78   c->set_output(0, c->MakeShape(dims));
79   return OkStatus();
80 }
81 
PadShapeFn(InferenceContext * c)82 Status PadShapeFn(InferenceContext* c) {
83   // Paddings is a matrix of [input_rank, 2].
84   ShapeHandle paddings;
85   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
86   DimensionHandle unused;
87   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(paddings, 1), 2, &unused));
88 
89   // n_dim and input.rank are equivalent.
90   ShapeHandle input = c->input(0);
91   DimensionHandle n_dim = c->Dim(paddings, 0);
92   if (c->ValueKnown(n_dim)) {
93     TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(n_dim), &input));
94   } else if (c->RankKnown(input)) {
95     TF_RETURN_IF_ERROR(c->WithValue(n_dim, c->Rank(input), &n_dim));
96   }
97 
98   const Tensor* paddings_t = c->input_tensor(1);
99 
100   // paddings_t is unknown
101   if (paddings_t == nullptr) {
102     if (c->ValueKnown(n_dim)) {
103       // Make output with n_dim unknown dims.
104       c->set_output(0, c->UnknownShapeOfRank(c->Value(n_dim)));
105     } else {
106       c->set_output(0, c->UnknownShape());
107     }
108     return OkStatus();
109   }
110 
111   const int64_t num_dims = paddings_t->shape().dim_size(0);
112   TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input));
113   TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim));
114 
115   if (paddings_t->dtype() == DT_INT32) {
116     return PadKnown<int32>(c, input, paddings_t, num_dims);
117   } else {
118     return PadKnown<int64_t>(c, input, paddings_t, num_dims);
119   }
120 }
121 
TransposeShapeFn(InferenceContext * c)122 Status TransposeShapeFn(InferenceContext* c) {
123   ShapeHandle input = c->input(0);
124   ShapeHandle perm_shape = c->input(1);
125   const Tensor* perm = c->input_tensor(1);
126   DimensionHandle perm_elems = c->NumElements(perm_shape);
127   // If we don't have rank information on the input or value information on
128   // perm we can't return any shape information, otherwise we have enough
129   // information to at least find the rank of the output.
130   if (!c->RankKnown(input) && !c->ValueKnown(perm_elems) && perm == nullptr) {
131     c->set_output(0, c->UnknownShape());
132     return OkStatus();
133   }
134 
135   // Find our value of the rank.
136   int64_t rank;
137   if (c->RankKnown(input)) {
138     rank = c->Rank(input);
139   } else if (c->ValueKnown(perm_elems)) {
140     rank = c->Value(perm_elems);
141   } else {
142     rank = perm->NumElements();
143   }
144   if (!c->RankKnown(input) && rank < 2) {
145     // A permutation array containing a single element is ambiguous. It could
146     // indicate either a scalar or a 1-dimensional array, both of which the
147     // transpose op returns unchanged.
148     c->set_output(0, input);
149     return OkStatus();
150   }
151 
152   std::vector<DimensionHandle> dims;
153   dims.resize(rank);
154   TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
155   // Ensure that perm is a vector and has rank elements.
156   TF_RETURN_IF_ERROR(c->WithRank(perm_shape, 1, &perm_shape));
157   TF_RETURN_IF_ERROR(c->WithValue(perm_elems, rank, &perm_elems));
158 
159   // If we know the rank of the input and the value of perm, we can return
160   // all shape information, otherwise we can only return rank information,
161   // but no information for the dimensions.
162   if (perm != nullptr) {
163     std::vector<int64_t> data;
164     if (perm->dtype() == DT_INT32) {
165       data = AsInt64<int32>(perm, rank);
166     } else {
167       data = AsInt64<int64_t>(perm, rank);
168     }
169 
170     for (int32_t i = 0; i < rank; ++i) {
171       int64_t in_idx = data[i];
172       if (in_idx >= rank || in_idx <= -rank) {
173         return errors::InvalidArgument("perm dim ", in_idx,
174                                        " is out of range of input rank ", rank);
175       }
176       dims[i] = c->Dim(input, in_idx);
177     }
178   } else {
179     for (int i = 0; i < rank; ++i) {
180       dims[i] = c->UnknownDim();
181     }
182   }
183 
184   c->set_output(0, c->MakeShape(dims));
185   return OkStatus();
186 }
187 
SetOutputShapeForReshape(InferenceContext * c)188 Status SetOutputShapeForReshape(InferenceContext* c) {
189   ShapeHandle in = c->input(0);
190   ShapeHandle out;
191   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
192 
193   if (!c->RankKnown(out)) {
194     // We have no information about the shape of the output.
195     c->set_output(0, out);
196     return OkStatus();
197   }
198   if (c->RankKnown(in)) {
199     // We don't know the number of output elements, but we can try to infer
200     // the missing dimension.
201     bool too_many_unknown = false;
202     int32_t out_unknown_idx = -1;
203 
204     DimensionHandle known_out_elems = c->NumElements(out);
205     if (!c->ValueKnown(known_out_elems)) {
206       known_out_elems = c->MakeDim(1);
207       for (int32_t i = 0; i < c->Rank(out); ++i) {
208         DimensionHandle dim = c->Dim(out, i);
209         if (!c->ValueKnown(dim)) {
210           if (out_unknown_idx >= 0) {
211             too_many_unknown = true;
212             break;
213           }
214           out_unknown_idx = i;
215         } else {
216           TF_RETURN_IF_ERROR(
217               c->Multiply(known_out_elems, dim, &known_out_elems));
218         }
219       }
220     }
221     int32_t in_unknown_idx = -1;
222     DimensionHandle known_in_elems = c->NumElements(in);
223     if (!c->ValueKnown(known_in_elems)) {
224       known_in_elems = c->MakeDim(1);
225       for (int32_t i = 0; i < c->Rank(in); ++i) {
226         DimensionHandle dim = c->Dim(in, i);
227         if (!c->ValueKnown(dim)) {
228           if (in_unknown_idx >= 0) {
229             too_many_unknown = true;
230             break;
231           }
232           in_unknown_idx = i;
233         } else {
234           TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems));
235         }
236       }
237     }
238 
239     if (!too_many_unknown) {
240       if (in_unknown_idx < 0 && out_unknown_idx < 0) {
241         // Just check that the dimensions match.
242         if (c->Value(known_in_elems) != c->Value(known_out_elems)) {
243           return errors::InvalidArgument(
244               "Cannot reshape a tensor with ", c->DebugString(known_in_elems),
245               " elements to shape ", c->DebugString(out), " (",
246               c->DebugString(known_out_elems), " elements)");
247         }
248       } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 &&
249                  c->Value(known_out_elems) > 0) {
250         // Input fully known, infer the one missing output dim
251         DimensionHandle inferred_dim;
252         TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems),
253                                      true /* evenly_divisible */,
254                                      &inferred_dim));
255         TF_RETURN_IF_ERROR(
256             c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out));
257 
258       } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 &&
259                  c->Value(known_in_elems) != 0) {
260         // Output fully known, infer the one missing input dim
261         DimensionHandle inferred_dim;
262         TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems),
263                                      true /* evenly_divisible */,
264                                      &inferred_dim));
265         DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
266         TF_RETURN_IF_ERROR(
267             c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim));
268       } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) {
269         // Exactly one unknown dimension in both input and output. These 2 are
270         // equal iff the known elements are equal.
271         if (c->Value(known_in_elems) == c->Value(known_out_elems)) {
272           DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
273           TF_RETURN_IF_ERROR(
274               c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out));
275         }
276       }
277     }
278   }
279   c->set_output(0, out);
280   return OkStatus();
281 }
282 
283 }  // namespace
284 
285 REGISTER_OP("ParallelConcat")
286     .Input("values: N * T")
287     .Output("output: T")
288     .Attr("N: int >= 1")
289     .Attr("T: type")
290     .Attr("shape: shape")
__anon38bbb0e80202(InferenceContext* c) 291     .SetShapeFn([](InferenceContext* c) {
292       // Validate that the shape attr is correct.
293       PartialTensorShape shape;
294       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
295       ShapeHandle passed_shape;
296       TF_RETURN_IF_ERROR(
297           c->MakeShapeFromPartialTensorShape(shape, &passed_shape));
298       if (!c->FullyDefined(passed_shape)) {
299         return errors::InvalidArgument("shape attr must be fully defined.");
300       }
301       ShapeHandle cur;
302       TF_RETURN_IF_ERROR(c->ReplaceDim(
303           passed_shape, 0, c->MakeDim(shape_inference::DimensionOrConstant(1)),
304           &cur));
305       for (int i = 0; i < c->num_inputs(); ++i) {
306         if (!c->FullyDefined(c->input(i))) {
307           return errors::InvalidArgument(
308               "All input shapes must be fully defined.");
309         }
310         DimensionHandle unused;
311         if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) {
312           return errors::InvalidArgument("Size of first dimension must be 1.");
313         }
314         TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
315                                         "From merging shape ", i,
316                                         " with other shapes.");
317       }
318 
319       c->set_output(0, passed_shape);
320 
321       return OkStatus();
322     });
323 
324 REGISTER_OP("Pack")
325     .Input("values: N * T")
326     .Output("output: T")
327     .Attr("N: int >= 1")
328     .Attr("T: type")
329     .Attr("axis: int = 0")
__anon38bbb0e80302(InferenceContext* c) 330     .SetShapeFn([](InferenceContext* c) {
331       // Validate shapes of all inputs are compatible
332       ShapeHandle cur = c->input(c->num_inputs() - 1);
333       for (int i = c->num_inputs() - 2; i >= 0; --i) {
334         TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
335                                         "From merging shape ", i,
336                                         " with other shapes.");
337       }
338       if (!c->RankKnown(cur)) {
339         c->set_output(0, c->UnknownShape());
340         return OkStatus();
341       }
342       // Determine the axis that will be added, converting from negative
343       // axes to a positive point per negative indexing rules.
344       int32_t rank = c->Rank(cur);
345       int32_t axis;
346       TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis));
347 
348       // Copy all dimensions over, inserting a dimension of value #inputs
349       // at <axis>.
350       std::vector<DimensionHandle> dims;
351       int index = 0;
352       while (index < axis) dims.push_back(c->Dim(cur, index++));
353       dims.push_back(c->MakeDim(c->num_inputs()));
354       while (index < rank) dims.push_back(c->Dim(cur, index++));
355 
356       c->set_output(0, c->MakeShape(dims));
357       for (int i = 0; i < c->num_inputs(); ++i) {
358         auto* shape_and_type = c->input_handle_shapes_and_types(i);
359         if (shape_and_type) {
360           if (!c->RelaxOutputHandleShapesAndMergeTypes(0, *shape_and_type)) {
361             c->set_output_handle_shapes_and_types(
362                 0, std::vector<shape_inference::ShapeAndType>({}));
363             break;
364           }
365         }
366       }
367       return OkStatus();
368     });
369 
370 REGISTER_OP("DeepCopy")
371     .Input("x: T")
372     .Output("y: T")
373     .Attr("T: type")
374     .SetIsStateful()
375     .SetShapeFn(UnchangedShape);
376 
377 REGISTER_OP("InplaceUpdate")
378     .Input("x: T")
379     .Input("i: int32")
380     .Input("v: T")
381     .Output("y: T")
382     .Attr("T: type")
383     .SetShapeFn(UnchangedShape);
384 
385 REGISTER_OP("InplaceAdd")
386     .Input("x: T")
387     .Input("i: int32")
388     .Input("v: T")
389     .Output("y: T")
390     .Attr("T: type")
391     .SetShapeFn(UnchangedShape);
392 
393 REGISTER_OP("InplaceSub")
394     .Input("x: T")
395     .Input("i: int32")
396     .Input("v: T")
397     .Output("y: T")
398     .Attr("T: type")
399     .SetShapeFn(UnchangedShape);
400 
401 REGISTER_OP("Empty")
402     .Input("shape: int32")
403     .Output("output: dtype")
404     .Attr("dtype: type")
405     .Attr("init: bool = false")
406     .SetDoNotOptimize()
__anon38bbb0e80402(InferenceContext* c) 407     .SetShapeFn([](InferenceContext* c) {
408       ShapeHandle out;
409       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
410       c->set_output(0, out);
411       return OkStatus();
412     });
413 
414 // --------------------------------------------------------------------------
415 REGISTER_OP("Unpack")
416     .Input("value: T")
417     .Output("output: num * T")
418     .Attr("num: int >= 0")
419     .Attr("T: type")
420     .Attr("axis: int = 0")
__anon38bbb0e80502(InferenceContext* c) 421     .SetShapeFn([](InferenceContext* c) {
422       ShapeHandle s = c->input(0);
423       ShapeHandle out;
424       if (c->RankKnown(s)) {
425         // Determine the axis that will be removed, converting from negative
426         // axes to a positive point per negative indexing rules.
427         int32_t rank = c->Rank(s);
428         int32_t axis;
429         TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));
430 
431         // The axis dim matches the number of outputs.
432         DimensionHandle unused;
433         TF_RETURN_IF_ERROR(
434             c->WithValue(c->Dim(s, axis), c->num_outputs(), &unused));
435 
436         // Copy all dimensions, removing the <axis> dimension.
437         std::vector<DimensionHandle> dims;
438         for (int i = 0; i < rank; ++i) {
439           if (i != axis) dims.push_back(c->Dim(s, i));
440         }
441         out = c->MakeShape(dims);
442       } else {
443         // All outputs are the same shape, but it's not known.
444         out = c->UnknownShape();
445       }
446       for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out);
447       return OkStatus();
448     });
449 
450 REGISTER_OP("UnravelIndex")
451     .Input("indices: Tidx")
452     .Input("dims: Tidx")
453     .Output("output: Tidx")
454     .Attr("Tidx: {int32, int64} = DT_INT32")
__anon38bbb0e80602(InferenceContext* c) 455     .SetShapeFn([](InferenceContext* c) {
456       ShapeHandle indices = c->input(0);
457       ShapeHandle dims;
458       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
459       if (c->RankKnown(indices) && c->Rank(indices) == 0) {
460         c->set_output(0, c->Vector(c->Dim(dims, 0)));
461       } else if (c->RankKnown(indices)) {
462         c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices)));
463       } else {
464         c->set_output(0, c->UnknownShape());
465       }
466       return OkStatus();
467     });
468 
469 REGISTER_OP("BroadcastTo")
470     .Input("input: T")
471     .Input("shape: Tidx")
472     .Output("output: T")
473     .Attr("T: type")
474     .Attr("Tidx: {int32, int64} = DT_INT32")
__anon38bbb0e80702(InferenceContext* c) 475     .SetShapeFn([](InferenceContext* c) {
476       ShapeHandle shape_in = c->input(1);
477       TF_RETURN_IF_ERROR(c->WithRank(shape_in, 1, &shape_in));
478       ShapeHandle out;
479       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
480       if (!c->RankKnown(out)) {
481         // We have no information about the shape of the output.
482         c->set_output(0, out);
483         return OkStatus();
484       }
485 
486       ShapeHandle in = c->input(0);
487       if (!c->RankKnown(in)) {
488         // We have no information about the shape of the input,
489         // nothing to do here.
490         c->set_output(0, out);
491         return OkStatus();
492       }
493       int out_rank = c->Rank(out);
494       TF_RETURN_IF_ERROR(c->WithRankAtMost(in, out_rank, &in));
495       int in_rank = c->Rank(in);
496       for (int i = 0; i < in_rank; ++i) {
497         auto in_dim = c->Dim(in, in_rank - i - 1);
498         if (c->Value(in_dim) > 1) {
499           // If the input dimension is greater than 1 then the output dimension
500           // must be equal to it, since we only broadcast "from left to right".
501           auto out_dim = c->Dim(out, out_rank - i - 1);
502           TF_RETURN_IF_ERROR(c->Merge(in_dim, out_dim, &out_dim));
503           TF_RETURN_IF_ERROR(
504               c->ReplaceDim(out, out_rank - i - 1, out_dim, &out));
505         }
506       }
507       c->set_output(0, out);
508       return OkStatus();
509     });
510 
511 // --------------------------------------------------------------------------
512 // TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph
513 // in the N == 1 case to remove the node.
514 REGISTER_OP("Concat")
515     .Input("concat_dim: int32")
516     .Input("values: N * T")
517     .Output("output: T")
518     .Attr("N: int >= 2")
519     .Attr("T: type")
__anon38bbb0e80802(InferenceContext* c) 520     .SetShapeFn([](InferenceContext* c) {
521       return shape_inference::ConcatShape(c, c->num_inputs() - 1);
522     });
523 
524 REGISTER_OP("ConcatV2")
525     .Input("values: N * T")
526     .Input("axis: Tidx")
527     .Output("output: T")
528     .Attr("N: int >= 2")
529     .Attr("T: type")
530     .Attr("Tidx: {int32, int64} = DT_INT32")
531     .SetShapeFn(shape_inference::ConcatV2Shape);
532 
533 // TODO([email protected]): Prefix the op names with underscore if the ops
534 // are not to be made user-accessible.
#ifdef INTEL_MKL
// _MklConcatV2: registered only when INTEL_MKL is defined. Same value/axis
// inputs as ConcatV2 plus uint8 `mkl_values` / `mkl_axis` inputs and a uint8
// `mkl_output`; shapes are inferred by shape_inference::ConcatV2Shape.
REGISTER_OP("_MklConcatV2")
    .Input("values: N * T")
    .Input("axis: Tidx")
    .Input("mkl_values: N * uint8")
    .Input("mkl_axis: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ConcatV2Shape)
    .Doc(R"doc(
MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
#endif  // INTEL_MKL
554 
555 REGISTER_OP("ConcatOffset")
556     .Input("concat_dim: int32")
557     .Input("shape: N * int32")
558     .Output("offset: N * int32")
559     .Attr("N: int >= 2")
__anon38bbb0e80902(InferenceContext* c) 560     .SetShapeFn([](InferenceContext* c) {
561       for (int i = 1; i < c->num_inputs(); ++i) {
562         c->set_output(i - 1, c->input(i));
563       }
564       return OkStatus();
565     });
566 
567 // --------------------------------------------------------------------------
568 REGISTER_OP("Split")
569     .Input("split_dim: int32")
570     .Input("value: T")
571     .Output("output: num_split * T")
572     .Attr("num_split: int >= 1")
573     .Attr("T: type")
__anon38bbb0e80a02(InferenceContext* c) 574     .SetShapeFn([](InferenceContext* c) {
575       DimensionHandle split_dimension;
576       ShapeHandle input = c->input(1);
577       TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
578           0, c->Rank(input), &split_dimension));
579       int num_split = c->num_outputs();
580       ShapeHandle out;
581       if (!c->ValueKnown(split_dimension)) {
582         if (c->RankKnown(input)) {
583           out = c->UnknownShapeOfRank(c->Rank(input));
584         } else {
585           out = c->UnknownShape();
586         }
587       } else {
588         int64_t split_dim = c->Value(split_dimension);
589         TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
590         DimensionHandle split_dim_size;
591         TF_RETURN_WITH_CONTEXT_IF_ERROR(
592             c->Divide(c->Dim(input, split_dim), num_split,
593                       true /* evenly_divisible */, &split_dim_size),
594             "Number of ways to split should evenly divide the split dimension");
595         TF_RETURN_IF_ERROR(
596             c->ReplaceDim(input, split_dim, split_dim_size, &out));
597       }
598       for (int i = 0; i < num_split; ++i) c->set_output(i, out);
599       return OkStatus();
600     });
601 
602 REGISTER_OP("SplitV")
603     .Input("value: T")
604     .Input("size_splits: Tlen")
605     .Input("split_dim: int32")
606     .Output("output: num_split * T")
607     .Attr("num_split: int >= 1")
608     .Attr("T: type")
609     .Attr("Tlen: {int32, int64} = DT_INT64")
__anon38bbb0e80b02(InferenceContext* c) 610     .SetShapeFn([](InferenceContext* c) {
611       DimensionHandle split_dimension;
612       ShapeHandle input = c->input(0);
613       TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
614           2, c->Rank(input), &split_dimension));
615       int32_t num_outputs = c->num_outputs();
616       int32_t rank = c->Rank(input);
617       ShapeHandle output_shape;
618       const Tensor* size_splits = c->input_tensor(1);
619       if (rank == InferenceContext::kUnknownRank) {
620         // If the rank of input tensor is unknown, then return unknown shapes.
621         // Note that the shape of each output can be different.
622         for (int i = 0; i < num_outputs; ++i) {
623           c->set_output(i, c->UnknownShape());
624         }
625       } else if (rank == 0) {
626         // Throw error if input is a scalar.
627         return errors::InvalidArgument("Can't split scalars");
628       } else if (size_splits == nullptr && c->ValueKnown(split_dimension)) {
629         // If split dimension is known, but the sizes are unknown, then
630         // only the split dimension is unknown
631         output_shape = input;
632         for (int i = 0; i < num_outputs; ++i) {
633           TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape,
634                                            c->Value(split_dimension),
635                                            c->UnknownDim(), &output_shape));
636           c->set_output(i, output_shape);
637         }
638       } else if (size_splits == nullptr && !c->ValueKnown(split_dimension)) {
639         // If split dimension or tensor containing the split sizes is unknown,
640         // then return unknown shapes of same rank as input. Note that each
641         // output shape can be different since splitv doesn't always split
642         // tensors evenly.
643         for (int i = 0; i < num_outputs; ++i) {
644           c->set_output(i, c->UnknownShapeOfRank(rank));
645         }
646       } else {
647         // Determine the output shape if split dimension and split sizes are
648         // known.
649         int64_t split_dim = c->Value(split_dimension);
650         TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
651         std::vector<int64_t> data;
652         if (size_splits->dtype() == DT_INT32) {
653           data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0));
654         } else {
655           data =
656               AsInt64<int64_t>(size_splits, size_splits->shape().dim_size(0));
657         }
658         if (num_outputs != data.size()) {
659           return errors::InvalidArgument(
660               "Length of size_splits should be equal to num_outputs");
661         }
662         int64_t total_size = 0;
663         bool has_neg_one = false;
664         for (const auto size : data) {
665           if (size == -1) {
666             if (has_neg_one) {
667               return errors::InvalidArgument(
668                   "size_splits can only have one -1");
669             }
670             has_neg_one = true;
671           } else {
672             total_size += size;
673           }
674         }
675         auto split_dim_size = c->Value(c->Dim(input, split_dim));
676         // If the sizes of the splits are known, then
677         // make sure that the sizes add up to the expected
678         // dimension size, with the possibility of a -1.
679         // Specify the full output shapes.
680         for (int i = 0; i < num_outputs; ++i) {
681           auto size = data[i];
682           if (data[i] == -1 && c->ValueKnown(split_dim_size)) {
683             size = split_dim_size - total_size;
684           }
685           // If we have a negative known size (either explicit, or computed
686           // via -1), then the split sizes are invalid.
687           if (size < -1 || (size == -1 && c->ValueKnown(split_dim_size))) {
688             return errors::InvalidArgument("Split size at index ", i,
689                                            " must be >= 0. Got: ", size);
690           }
691           TF_RETURN_IF_ERROR(
692               c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape));
693           c->set_output(i, output_shape);
694         }
695         if (c->ValueKnown(split_dim_size)) {
696           if (has_neg_one ? total_size > split_dim_size
697                           : total_size != split_dim_size) {
698             return errors::InvalidArgument(
699                 "can't split axis of size ", split_dim_size,
700                 " into pieces of size [", absl::StrJoin(data, ","), "]");
701           }
702         }
703       }
704 
705       return OkStatus();
706     });
707 
708 // --------------------------------------------------------------------------
709 REGISTER_OP("Const")
710     .Output("output: dtype")
711     .Attr("value: tensor")
712     .Attr("dtype: type")
__anon38bbb0e80c02(InferenceContext* c) 713     .SetShapeFn([](InferenceContext* c) {
714       const TensorProto* proto = nullptr;
715       TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
716       TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
717       TensorShape shape(proto->tensor_shape());
718       std::vector<DimensionHandle> dims;
719       dims.reserve(shape.dims());
720       for (int i = 0; i < shape.dims(); ++i) {
721         dims.push_back(c->MakeDim(shape.dim_size(i)));
722       }
723       c->set_output(0, c->MakeShape(dims));
724       return OkStatus();
725     });
726 
727 // Returns a constant tensor on the host.  Useful for writing C++ tests
728 // and benchmarks which run on GPU but require arguments pinned to the host.
729 // Used by test::graph::HostConstant.
730 // value: Attr `value` is the tensor to return.
731 REGISTER_OP("HostConst")
732     .Output("output: dtype")
733     .Attr("value: tensor")
734     .Attr("dtype: type")
735     .SetShapeFn(shape_inference::UnknownShape);
736 
737 // Used executing op-by-op to copy constants to the current device without
738 // serializing tensors as TensorProtos, after a host tensor has been
739 // created. Same behavior as Identity, but no gradient and potentially relaxed
740 // copy semantics.
741 REGISTER_OP("_EagerConst")
742     .Input("input: T")
743     .Output("output: T")
744     .Attr("T: type")
745     .SetShapeFn(shape_inference::UnchangedShape);
746 
747 // --------------------------------------------------------------------------
748 // TODO(mgubin): Update the doc when the freeze_graph script supports converting
749 // into memmapped format.
750 REGISTER_OP("ImmutableConst")
751     .Attr("dtype: type")
752     .Attr("shape: shape")
753     .Attr("memory_region_name: string")
754     .Output("tensor: dtype")
755     .SetShapeFn(shape_inference::ExplicitShape);
756 
757 REGISTER_OP("GuaranteeConst")
758     .Input("input: T")
759     .Output("output: T")
760     .Attr("T: type")
__anon38bbb0e80d02(shape_inference::InferenceContext* c) 761     .SetShapeFn([](shape_inference::InferenceContext* c) {
762       return UnchangedShape(c);
763     })
764     // We don't want this to be optimized away.
765     .SetDoNotOptimize();
766 
767 // --------------------------------------------------------------------------
768 REGISTER_OP("ZerosLike")
769     .Input("x: T")
770     .Output("y: T")
771     .Attr("T: type")
772     .SetShapeFn(shape_inference::UnchangedShape);
773 
774 // --------------------------------------------------------------------------
775 REGISTER_OP("OnesLike")
776     .Input("x: T")
777     .Output("y: T")
778     .Attr(
779         "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, int32, "
780         "uint32, int64, uint64, complex64, complex128, bool}")
781     .SetShapeFn(shape_inference::UnchangedShape);
782 
783 // --------------------------------------------------------------------------
784 REGISTER_OP("Diag")
785     .Input("diagonal: T")
786     .Output("output: T")
787     .Attr(
788         "T: {bfloat16, half, float, double, int32, int64, complex64, "
789         "complex128}")
__anon38bbb0e80e02(InferenceContext* c) 790     .SetShapeFn([](InferenceContext* c) {
791       ShapeHandle in = c->input(0);
792       TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in));
793       // Output shape is original concatenated with itself.
794       ShapeHandle out;
795       TF_RETURN_IF_ERROR(c->Concatenate(in, in, &out));
796       c->set_output(0, out);
797       return OkStatus();
798     });
799 
800 // --------------------------------------------------------------------------
801 REGISTER_OP("DiagPart")
802     .Input("input: T")
803     .Output("diagonal: T")
804     .Attr(
805         "T: {bfloat16, half, float, double, int32, int64, complex64, "
806         "complex128}")
__anon38bbb0e80f02(InferenceContext* c) 807     .SetShapeFn([](InferenceContext* c) {
808       ShapeHandle in = c->input(0);
809       if (!c->RankKnown(in)) {
810         c->set_output(0, c->UnknownShape());
811         return OkStatus();
812       }
813       // Rank must be even, and result will have rank <rank/2>.
814       const int32_t rank = c->Rank(in);
815       if ((rank % 2) != 0 || rank <= 0) {
816         return errors::InvalidArgument(
817             "Input must have even and non-zero rank, input rank is ", rank);
818       }
819       const int32_t mid = rank / 2;
820 
821       // output dim[i] is the merge of in.dim[i] and in.dim[i+mid].
822       std::vector<DimensionHandle> dims(mid);
823       for (int i = 0; i < mid; ++i) {
824         TF_RETURN_IF_ERROR(
825             c->Merge(c->Dim(in, i), c->Dim(in, i + mid), &dims[i]));
826       }
827       c->set_output(0, c->MakeShape(dims));
828       return OkStatus();
829     });
830 
831 // --------------------------------------------------------------------------
832 REGISTER_OP("MatrixDiag")
833     .Input("diagonal: T")
834     .Output("output: T")
835     .Attr("T: type")
__anon38bbb0e81002(InferenceContext* c) 836     .SetShapeFn([](InferenceContext* c) {
837       ShapeHandle in;
838       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in));
839       if (!c->RankKnown(in)) {
840         c->set_output(0, c->UnknownShape());
841         return OkStatus();
842       }
843       const int32_t rank = c->Rank(in);
844       ShapeHandle out;
845       TF_RETURN_IF_ERROR(
846           c->Concatenate(in, c->Vector(c->Dim(in, rank - 1)), &out));
847       c->set_output(0, out);
848       return OkStatus();
849     });
850 
851 REGISTER_OP("MatrixDiagV2")
852     .Input("diagonal: T")
853     .Input("k: int32")
854     .Input("num_rows: int32")
855     .Input("num_cols: int32")
856     .Input("padding_value: T")
857     .Output("output: T")
858     .Attr("T: type")
859     .SetShapeFn(shape_inference::MatrixDiagV2Shape);
860 
861 REGISTER_OP("MatrixDiagV3")
862     .Input("diagonal: T")
863     .Input("k: int32")
864     .Input("num_rows: int32")
865     .Input("num_cols: int32")
866     .Input("padding_value: T")
867     .Output("output: T")
868     .Attr("T: type")
869     .Attr(
870         "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
871         "'RIGHT_LEFT'")
872     .SetShapeFn(shape_inference::MatrixDiagV2Shape);
873 
874 // --------------------------------------------------------------------------
875 REGISTER_OP("MatrixSetDiag")
876     .Input("input: T")
877     .Input("diagonal: T")
878     .Output("output: T")
879     .Attr("T: type")
__anon38bbb0e81102(InferenceContext* c) 880     .SetShapeFn([](InferenceContext* c) {
881       ShapeHandle input;
882       ShapeHandle diag;
883       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
884       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag));
885       if (c->RankKnown(input)) {
886         TF_RETURN_IF_ERROR(c->WithRank(c->input(1), c->Rank(input) - 1, &diag));
887       }
888       DimensionHandle smallest_dim;
889       TF_RETURN_IF_ERROR(
890           c->Min(c->Dim(input, -2), c->Dim(input, -1), &smallest_dim));
891       TF_RETURN_IF_ERROR(
892           c->Merge(smallest_dim, c->Dim(diag, -1), &smallest_dim));
893 
894       ShapeHandle output = input;
895       if (c->RankKnown(diag) && !c->FullyDefined(input)) {
896         // Try to infer parts of shape from diag.
897         ShapeHandle diag_batch_shape;
898         TF_RETURN_IF_ERROR(c->Subshape(diag, 0, -1, &diag_batch_shape));
899         TF_RETURN_IF_ERROR(
900             c->Concatenate(diag_batch_shape, c->UnknownShapeOfRank(2), &diag));
901         TF_RETURN_IF_ERROR(c->Merge(input, diag, &output));
902       }
903       c->set_output(0, output);
904       return OkStatus();
905     });
906 
907 REGISTER_OP("MatrixSetDiagV2")
908     .Input("input: T")
909     .Input("diagonal: T")
910     .Input("k: int32")
911     .Output("output: T")
912     .Attr("T: type")
913     .SetShapeFn(shape_inference::MatrixSetDiagV2Shape);
914 
915 REGISTER_OP("MatrixSetDiagV3")
916     .Input("input: T")
917     .Input("diagonal: T")
918     .Input("k: int32")
919     .Output("output: T")
920     .Attr("T: type")
921     .Attr(
922         "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
923         "'RIGHT_LEFT'")
924     .SetShapeFn(shape_inference::MatrixSetDiagV2Shape);
925 
926 // --------------------------------------------------------------------------
927 REGISTER_OP("MatrixDiagPart")
928     .Input("input: T")
929     .Output("diagonal: T")
930     .Attr("T: type")
__anon38bbb0e81202(InferenceContext* c) 931     .SetShapeFn([](InferenceContext* c) {
932       ShapeHandle in;
933       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &in));
934       if (!c->RankKnown(in)) {
935         c->set_output(0, c->UnknownShape());
936         return OkStatus();
937       }
938       const int32_t rank = c->Rank(in);
939       std::vector<DimensionHandle> dims;
940       dims.reserve(rank - 2);
941       for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i));
942 
943       DimensionHandle min_dim;
944       TF_RETURN_IF_ERROR(
945           c->Min(c->Dim(in, rank - 2), c->Dim(in, rank - 1), &min_dim));
946       dims.push_back(min_dim);
947       c->set_output(0, c->MakeShape(dims));
948       return OkStatus();
949     });
950 
951 REGISTER_OP("MatrixDiagPartV2")
952     .Input("input: T")
953     .Input("k: int32")
954     .Input("padding_value: T")
955     .Output("diagonal: T")
956     .Attr("T: type")
957     .SetShapeFn(shape_inference::MatrixDiagPartV2Shape);
958 
959 REGISTER_OP("MatrixDiagPartV3")
960     .Input("input: T")
961     .Input("k: int32")
962     .Input("padding_value: T")
963     .Output("diagonal: T")
964     .Attr("T: type")
965     .Attr(
966         "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
967         "'RIGHT_LEFT'")
968     .SetShapeFn(shape_inference::MatrixDiagPartV2Shape);
969 
970 // --------------------------------------------------------------------------
971 REGISTER_OP("MatrixBandPart")
972     .Input("input: T")
973     .Input("num_lower: Tindex")
974     .Input("num_upper: Tindex")
975     .Output("band: T")
976     .Attr("T: type")
977     .Attr("Tindex: {int32, int64} = DT_INT64")
978     .SetShapeFn(shape_inference::UnchangedShape);
979 
980 // --------------------------------------------------------------------------
981 REGISTER_OP("Reverse")
982     .Input("tensor: T")
983     .Input("dims: bool")
984     .Output("output: T")
985     .Attr(
986         "T: {uint8, int8, uint16, int16, uint32, int32, uint64, int64, bool, "
987         "bfloat16, half, float, double, complex64, complex128, string}")
__anon38bbb0e81302(InferenceContext* c) 988     .SetShapeFn([](InferenceContext* c) {
989       ShapeHandle input = c->input(0);
990       ShapeHandle dims;
991       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
992       DimensionHandle dims_dim = c->Dim(dims, 0);
993       if (c->ValueKnown(dims_dim)) {
994         TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(dims_dim), &input));
995       }
996       if (c->Rank(input) > 8) {
997         return errors::InvalidArgument(
998             "reverse does not work on tensors with more than 8 dimensions");
999       }
1000       c->set_output(0, input);
1001       return OkStatus();
1002     });
1003 
1004 // --------------------------------------------------------------------------
1005 REGISTER_OP("ReverseV2")
1006     .Input("tensor: T")
1007     .Input("axis: Tidx")
1008     .Output("output: T")
1009     .Attr("Tidx: {int32, int64} = DT_INT32")
1010     .Attr(
1011         "T: {uint8, int8, uint16, int16, int32, uint32, int64, uint64, bool, "
1012         "bfloat16, half, float, double, complex64, complex128, string}")
__anon38bbb0e81402(InferenceContext* c) 1013     .SetShapeFn([](InferenceContext* c) {
1014       ShapeHandle input = c->input(0);
1015       ShapeHandle axis;
1016       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &axis));
1017       if (c->Rank(input) > 8) {
1018         return errors::InvalidArgument(
1019             "reverse does not work on tensors with more than 8 dimensions");
1020       }
1021       const Tensor* axis_tensor = c->input_tensor(1);
1022       if (axis_tensor != nullptr && c->RankKnown(input)) {
1023         int32_t rank = c->Rank(input);
1024         std::vector<int64_t> axis_value;
1025         if (axis_tensor->dtype() == DT_INT32) {
1026           axis_value = AsInt64<int32>(axis_tensor, axis_tensor->NumElements());
1027         } else {
1028           axis_value =
1029               AsInt64<int64_t>(axis_tensor, axis_tensor->NumElements());
1030         }
1031         std::vector<bool> axes_dense(c->Rank(input), false);
1032         for (int i = 0; i < axis_value.size(); i++) {
1033           int64_t canonical_axis =
1034               axis_value[i] < 0 ? rank + axis_value[i] : axis_value[i];
1035           if (canonical_axis < 0 || canonical_axis >= rank) {
1036             return errors::InvalidArgument("'axis'[", i, "] = ", axis_value[i],
1037                                            " is out of valid range [", 0, ", ",
1038                                            rank - 1);
1039           }
1040           if (axes_dense[canonical_axis]) {
1041             return errors::InvalidArgument("axis ", canonical_axis,
1042                                            " specified more than once.");
1043           }
1044           axes_dense[canonical_axis] = true;
1045         }
1046       }
1047       c->set_output(0, input);
1048       return OkStatus();
1049     });
1050 
1051 // --------------------------------------------------------------------------
1052 REGISTER_OP("EditDistance")
1053     .Input("hypothesis_indices: int64")
1054     .Input("hypothesis_values: T")
1055     .Input("hypothesis_shape: int64")
1056     .Input("truth_indices: int64")
1057     .Input("truth_values: T")
1058     .Input("truth_shape: int64")
1059     .Attr("normalize: bool = true")
1060     .Attr("T: type")
1061     .Output("output: float")
__anon38bbb0e81502(InferenceContext* c) 1062     .SetShapeFn([](InferenceContext* c) {
1063       TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
1064           c, c->input(0), c->input(1), c->input(2)));
1065       TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
1066           c, c->input(3), c->input(4), c->input(5)));
1067       const Tensor* hypothesis_shape_t = c->input_tensor(2);
1068       const Tensor* truth_shape_t = c->input_tensor(5);
1069       if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) {
1070         // We need to know the runtime shape of the two tensors,
1071         // or else the output shape is unknown.
1072         return shape_inference::UnknownShape(c);
1073       }
1074 
1075       if (hypothesis_shape_t->NumElements() != truth_shape_t->NumElements()) {
1076         return errors::InvalidArgument(
1077             "Num elements of hypothesis_shape does not match truth_shape: ",
1078             hypothesis_shape_t->NumElements(), " vs. ",
1079             truth_shape_t->NumElements());
1080       }
1081 
1082       auto h_values = hypothesis_shape_t->flat<int64_t>();
1083       auto t_values = truth_shape_t->flat<int64_t>();
1084       std::vector<DimensionHandle> dims(hypothesis_shape_t->NumElements() - 1);
1085       for (int i = 0; i < dims.size(); ++i) {
1086         dims[i] = c->MakeDim(std::max(h_values(i), t_values(i)));
1087       }
1088 
1089       c->set_output(0, c->MakeShape(dims));
1090       return OkStatus();
1091     });
1092 
1093 // --------------------------------------------------------------------------
1094 REGISTER_OP("Fill")
1095     .Input("dims: index_type")
1096     .Input("value: T")
1097     .Output("output: T")
1098     .Attr("T: type")
1099     .Attr("index_type: {int32, int64} = DT_INT32")
__anon38bbb0e81602(InferenceContext* c) 1100     .SetShapeFn([](InferenceContext* c) {
1101       DataType index_type = DT_INT32;
1102       Status s = c->GetAttr("index_type", &index_type);
1103       if (!s.ok() && s.code() != error::NOT_FOUND) {
1104         return s;
1105       }
1106       ShapeHandle unused;
1107       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1108       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1109 
1110       const Tensor* t = c->input_tensor(0);
1111       if (t != nullptr) {
1112         for (int i = 0; i < t->NumElements(); ++i) {
1113           if ((index_type == DT_INT32 && t->vec<int32>()(i) < 0) ||
1114               (index_type == DT_INT64 && t->vec<int64_t>()(i) < 0)) {
1115             return errors::InvalidArgument("Fill dimensions must be >= 0");
1116           }
1117         }
1118       }
1119 
1120       ShapeHandle out;
1121       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1122       c->set_output(0, out);
1123 
1124       auto* shape_and_type = c->input_handle_shapes_and_types(1);
1125       if (shape_and_type) {
1126         c->set_output_handle_shapes_and_types(0, *shape_and_type);
1127       }
1128 
1129       return OkStatus();
1130     });
1131 
1132 // --------------------------------------------------------------------------
1133 REGISTER_OP("_ParallelConcatStart")
1134     .Output("output: dtype")
1135     .Attr("shape: shape")
1136     .Attr("dtype: type")
1137     .SetIsStateful()
1138     .SetShapeFn(shape_inference::ExplicitShape)
1139     .Doc(R"doc(
1140 Creates an empty Tensor with shape `shape` and type `dtype`.
1141 
1142 The memory can optionally be initialized. This is usually useful in
1143 conjunction with inplace operations.
1144 
1145 shape: 1-D `Tensor` indicating the shape of the output.
1146 dtype: The element type of the returned tensor.
1147 output: An empty Tensor of the specified type.
1148 )doc");
1149 
1150 // --------------------------------------------------------------------------
1151 REGISTER_OP("_ParallelConcatUpdate")
1152     .Input("value: T")
1153     .Input("update: T")
1154     .Output("output: T")
1155     .Attr("T: type")
1156     .Attr("loc: int")
1157     .SetShapeFn(shape_inference::UnchangedShape)
1158     .Doc(R"doc(
1159 Updates input `value` at `loc` with `update`.
1160 
1161 If you use this function you will almost certainly want to add
1162 a control dependency as done in the implementation of parallel_stack to
1163 avoid race conditions.
1164 
1165 value: A `Tensor` object that will be updated in-place.
1166 loc: A scalar indicating the index of the first dimension such that
1167          value[loc, :] is updated.
1168 update: A `Tensor` of rank one less than `value` if `loc` is a scalar,
1169         otherwise of rank equal to `value` that contains the new values
1170         for `value`.
1171 output: `value` that has been updated accordingly.
1172 )doc");
1173 
1174 // --------------------------------------------------------------------------
1175 REGISTER_OP("Gather")
1176     .Input("params: Tparams")
1177     .Input("indices: Tindices")
1178     .Attr("validate_indices: bool = true")
1179     .Output("output: Tparams")
1180     .Attr("Tparams: type")
1181     .Attr("Tindices: {int32,int64}")
__anon38bbb0e81702(InferenceContext* c) 1182     .SetShapeFn([](InferenceContext* c) {
1183       ShapeHandle unused;
1184       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
1185       ShapeHandle params_subshape;
1186       TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, &params_subshape));
1187       ShapeHandle indices_shape = c->input(1);
1188       ShapeHandle out;
1189       TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out));
1190       c->set_output(0, out);
1191       return OkStatus();
1192     });
1193 
1194 // --------------------------------------------------------------------------
1195 REGISTER_OP("GatherV2")
1196     .Input("params: Tparams")
1197     .Input("indices: Tindices")
1198     .Input("axis: Taxis")
1199     .Attr("batch_dims: int = 0")
1200     .Output("output: Tparams")
1201     .Attr("Tparams: type")
1202     .Attr("Tindices: {int16, int32,int64}")
1203     .Attr("Taxis: {int32,int64}")
__anon38bbb0e81802(InferenceContext* c) 1204     .SetShapeFn([](InferenceContext* c) {
1205       ShapeHandle params_shape;
1206       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &params_shape));
1207 
1208       ShapeHandle indices_shape = c->input(1);
1209       ShapeHandle unused_axis_shape;
1210       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_axis_shape));
1211       const Tensor* axis_t = c->input_tensor(2);
1212 
1213       // If axis is unknown, we can only infer that the result is params_rank +
1214       // indices_rank - 1.
1215       if (axis_t == nullptr) {
1216         if (c->RankKnown(params_shape) && c->RankKnown(indices_shape)) {
1217           int32_t batch_dims;
1218           TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims));
1219           c->set_output(0, c->UnknownShapeOfRank(c->Rank(params_shape) +
1220                                                  c->Rank(indices_shape) - 1 -
1221                                                  batch_dims));
1222         } else {
1223           c->set_output(0, c->UnknownShape());
1224         }
1225         return OkStatus();
1226       }
1227 
1228       // Note, axis can be negative.
1229       int64_t axis = 0;
1230       if (axis_t->dtype() == DT_INT32) {
1231         axis = axis_t->scalar<int32>()();
1232       } else {
1233         axis = axis_t->scalar<int64_t>()();
1234       }
1235 
1236       // Check that params has rank of at least axis + 1.
1237       ShapeHandle unused;
1238       TF_RETURN_IF_ERROR(c->WithRankAtLeast(
1239           params_shape, axis < 0 ? -axis : axis + 1, &unused));
1240 
1241       // Note, batch_dims can be negative.
1242       int32_t batch_dims;
1243       TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims));
1244       // -rank(indices) <= batch_dims <= rank(indices)
1245       TF_RETURN_IF_ERROR(
1246           c->WithRankAtLeast(indices_shape, std::abs(batch_dims), &unused));
1247       if (batch_dims < 0) {
1248         batch_dims += c->Rank(indices_shape);
1249       }
1250       // rank(params) > batch_dims
1251       TF_RETURN_IF_ERROR(
1252           c->WithRankAtLeast(params_shape, batch_dims + 1, &unused));
1253 
1254       ShapeHandle params_outer_subshape;
1255       TF_RETURN_IF_ERROR(
1256           c->Subshape(params_shape, 0, axis, &params_outer_subshape));
1257 
1258       ShapeHandle indices_inner_subshape;
1259       TF_RETURN_IF_ERROR(
1260           c->Subshape(indices_shape, batch_dims, &indices_inner_subshape));
1261 
1262       ShapeHandle out;
1263       TF_RETURN_IF_ERROR(
1264           c->Concatenate(params_outer_subshape, indices_inner_subshape, &out));
1265 
1266       // Slice from axis + 1 to the end of params_shape to collect the inner
1267       // dimensions of the result. Special case -1 here since -1 + 1 wraps, and
1268       // we slice from 0 to the end of shape. Subshape() handles all other
1269       // out-of-bounds checking.
1270       if (axis != -1) {
1271         ShapeHandle params_inner_subshape;
1272         TF_RETURN_IF_ERROR(
1273             c->Subshape(params_shape, axis + 1, &params_inner_subshape));
1274         TF_RETURN_IF_ERROR(c->Concatenate(out, params_inner_subshape, &out));
1275       }
1276 
1277       c->set_output(0, out);
1278       return OkStatus();
1279     });
1280 
1281 // --------------------------------------------------------------------------
1282 REGISTER_OP("GatherNd")
1283     .Input("params: Tparams")
1284     .Input("indices: Tindices")
1285     .Output("output: Tparams")
1286     .Attr("Tparams: type")
1287     .Attr("Tindices: {int16, int32,int64}")
1288     .SetShapeFn(shape_inference::GatherNdShape);
1289 
1290 // --------------------------------------------------------------------------
1291 REGISTER_OP("Identity")
1292     .Input("input: T")
1293     .Output("output: T")
1294     .Attr("T: type")
1295     .SetForwardTypeFn(full_type::ReplicateInput())
1296     .SetShapeFn(shape_inference::UnchangedShape);
1297 
1298 REGISTER_OP("Snapshot")
1299     .Input("input: T")
1300     .Output("output: T")
1301     .Attr("T: type")
1302     .SetShapeFn(shape_inference::UnchangedShape);
1303 
#ifdef INTEL_MKL
// Only registered in INTEL_MKL builds. Carries an extra uint8 `mkl_input` /
// `mkl_output` pair next to the data tensor.
REGISTER_OP("_MklIdentity")
    .Input("input: T")
    .Input("mkl_input: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: type")
    .SetShapeFn(UnchangedShape)
    .Doc(R"Doc( Mkl implementation of IdentityOp
)Doc");
#endif  // INTEL_MKL
1315 
1316 REGISTER_OP("IdentityN")
1317     .Input("input: T")
1318     .Output("output: T")
1319     .Attr("T: list(type)")
__anon38bbb0e81902(shape_inference::InferenceContext* c) 1320     .SetShapeFn([](shape_inference::InferenceContext* c) {
1321       std::vector<ShapeHandle> input;
1322       TF_RETURN_IF_ERROR(c->input("input", &input));
1323       TF_RETURN_IF_ERROR(c->set_output("output", input));
1324       // If any of the input shapes are not known, we should return error.
1325       for (int i = 0; i < input.size(); i++) {
1326         if (!input[i].Handle()) {
1327           return errors::InvalidArgument(absl::StrCat(
1328               "Cannot infer output shape #", i,
1329               " for IdentityN node because input shape #", i, " is unknown."));
1330         }
1331       }
1332       return OkStatus();
1333     });
1334 
1335 // --------------------------------------------------------------------------
1336 REGISTER_OP("RefIdentity")
1337     .Input("input: Ref(T)")
1338     .Output("output: Ref(T)")
1339     .Attr("T: type")
1340     .SetShapeFn(shape_inference::UnchangedShape)
1341     .SetAllowsUninitializedInput();
1342 
1343 // --------------------------------------------------------------------------
1344 REGISTER_OP("DebugGradientIdentity")
1345     .Input("input: T")
1346     .Output("output: T")
1347     .Attr("T: type")
1348     .SetShapeFn(shape_inference::UnchangedShape)
1349     .SetAllowsUninitializedInput();
1350 
1351 REGISTER_OP("DebugGradientRefIdentity")
1352     .Input("input: Ref(T)")
1353     .Output("output: Ref(T)")
1354     .Attr("T: type")
1355     .SetShapeFn(shape_inference::UnchangedShape)
1356     .SetAllowsUninitializedInput();
1357 
1358 // --------------------------------------------------------------------------
1359 REGISTER_OP("StopGradient")
1360     .Input("input: T")
1361     .Output("output: T")
1362     .Attr("T: type")
1363     .SetShapeFn(shape_inference::UnchangedShape);
1364 
1365 REGISTER_OP("PreventGradient")
1366     .Input("input: T")
1367     .Output("output: T")
1368     .Attr("T: type")
1369     .Attr("message: string = ''")
1370     .SetShapeFn(shape_inference::UnchangedShape);
1371 
1372 // --------------------------------------------------------------------------
1373 REGISTER_OP("CheckNumerics")
1374     .Input("tensor: T")
1375     .Output("output: T")
1376     .Attr("T: {bfloat16, half, float, double}")
1377     .Attr("message: string")
1378     .SetIsStateful()
1379     .SetShapeFn(shape_inference::UnchangedShape);
1380 
1381 // --------------------------------------------------------------------------
1382 REGISTER_OP("CheckNumericsV2")
1383     .Input("tensor: T")
1384     .Output("output: T")
1385     .Attr("T: {bfloat16, half, float, double}")
1386     .Attr("message: string")
1387     .SetIsStateful()
1388     .SetShapeFn(shape_inference::UnchangedShape);
1389 
1390 // --------------------------------------------------------------------------
1391 REGISTER_OP("Reshape")
1392     .Input("tensor: T")
1393     .Input("shape: Tshape")
1394     .Output("output: T")
1395     .Attr("T: type")
1396     .Attr("Tshape: {int32, int64} = DT_INT32")
__anon38bbb0e81a02(InferenceContext* c) 1397     .SetShapeFn([](InferenceContext* c) {
1398       return SetOutputShapeForReshape(c);
1399     });
1400 
#ifdef INTEL_MKL
// Only registered in INTEL_MKL builds. Uses the same shape function as
// Reshape.
REGISTER_OP("_MklReshape")
    .Input("tensor: T")
    .Input("shape: Tshape")
    .Input("mkl_tensor: uint8")
    .Input("mkl_shape: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: type")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn(SetOutputShapeForReshape)
    .Doc(R"Doc( MKL implementation of ReshapeOp.
)Doc");
#endif  // INTEL_MKL
1415 
1416 // --------------------------------------------------------------------------
1417 REGISTER_OP("InvertPermutation")
1418     .Input("x: T")
1419     .Output("y: T")
1420     .Attr("T: {int32, int64} = DT_INT32")
__anon38bbb0e81c02(InferenceContext* c) 1421     .SetShapeFn([](InferenceContext* c) {
1422       ShapeHandle x;
1423       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x));
1424       c->set_output(0, x);
1425       return OkStatus();
1426     });
1427 
1428 // --------------------------------------------------------------------------
1429 REGISTER_OP("Transpose")
1430     .Input("x: T")
1431     .Input("perm: Tperm")
1432     .Output("y: T")
1433     .Attr("T: type")
1434     .Attr("Tperm: {int32, int64} = DT_INT32")
1435     .SetShapeFn(TransposeShapeFn);
1436 
#ifdef INTEL_MKL
// Only registered in INTEL_MKL builds. Same signature and shape function as
// Transpose.
REGISTER_OP("_MklTranspose")
    .Input("x: T")
    .Input("perm: Tperm")
    .Output("y: T")
    .Attr("T: type")
    .Attr("Tperm: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) { return TransposeShapeFn(c); });
#endif  // INTEL_MKL
1446 
1447 // --------------------------------------------------------------------------
1448 REGISTER_OP("ConjugateTranspose")
1449     .Input("x: T")
1450     .Input("perm: Tperm")
1451     .Output("y: T")
1452     .Attr("T: type")
1453     .Attr("Tperm: {int32, int64} = DT_INT32")
1454     .SetShapeFn(TransposeShapeFn);
1455 
#ifdef INTEL_MKL
// Only registered in INTEL_MKL builds. Same signature and shape function as
// ConjugateTranspose.
REGISTER_OP("_MklConjugateTranspose")
    .Input("x: T")
    .Input("perm: Tperm")
    .Output("y: T")
    .Attr("T: type")
    .Attr("Tperm: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) { return TransposeShapeFn(c); });
#endif  // INTEL_MKL
1465 
1466 // --------------------------------------------------------------------------
1467 namespace {
UniqueIdxShapeFn(InferenceContext * c)1468 Status UniqueIdxShapeFn(InferenceContext* c) {
1469   ShapeHandle input = c->input(0);
1470   const Tensor* axis_t = c->input_tensor(1);
1471   if (axis_t == nullptr || !c->RankKnown(input)) {
1472     c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
1473     return OkStatus();
1474   }
1475 
1476   if (c->Rank(c->input(1)) != 1) {
1477     return errors::InvalidArgument("axis expects a 1D vector.");
1478   }
1479 
1480   int32_t n = axis_t->NumElements();
1481   if (n == 0) {
1482     if (c->Rank(input) != 1) {
1483       return errors::InvalidArgument("x expects a 1D vector.");
1484     }
1485     c->set_output(1, input);
1486     return OkStatus();
1487   } else if (n == 1) {
1488     int64_t axis;
1489     if (axis_t->dtype() == DT_INT32) {
1490       axis = static_cast<int64_t>(axis_t->flat<int32>()(0));
1491     } else {
1492       axis = axis_t->flat<int64_t>()(0);
1493     }
1494 
1495     int64_t input_rank = c->Rank(input);
1496     if (axis < -input_rank || axis >= input_rank) {
1497       return errors::InvalidArgument("axis expects to be in the range [",
1498                                      -input_rank, ", ", input_rank, ")");
1499     }
1500     if (axis < 0) {
1501       axis += input_rank;
1502     }
1503     c->set_output(1, c->Vector(c->Dim(input, axis)));
1504     return OkStatus();
1505   }
1506   return errors::InvalidArgument(
1507       "axis does not support input tensors larger than 1 elements.");
1508 }
1509 }  // namespace
1510 
1511 REGISTER_OP("Unique")
1512     .Input("x: T")
1513     .Output("y: T")
1514     .Output("idx: out_idx")
1515     .Attr("T: type")
1516     .Attr("out_idx: {int32, int64} = DT_INT32")
__anon38bbb0e81e02(InferenceContext* c) 1517     .SetShapeFn([](InferenceContext* c) {
1518       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1519       c->set_output(1, c->input(0));
1520       // Assert that the input rank is 1.
1521       ShapeHandle dummy;
1522       return c->WithRank(c->input(0), 1, &dummy);
1523     });
1524 
1525 REGISTER_OP("UniqueV2")
1526     .Input("x: T")
1527     .Input("axis: Taxis")
1528     .Output("y: T")
1529     .Output("idx: out_idx")
1530     .Attr("T: type")
1531     .Attr("Taxis: {int32,int64} = DT_INT64")
1532     .Attr("out_idx: {int32, int64} = DT_INT32")
__anon38bbb0e81f02(InferenceContext* c) 1533     .SetShapeFn([](InferenceContext* c) {
1534       c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
1535       TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
1536       return OkStatus();
1537     });
1538 
1539 // --------------------------------------------------------------------------
1540 REGISTER_OP("UniqueWithCounts")
1541     .Input("x: T")
1542     .Output("y: T")
1543     .Output("idx: out_idx")
1544     .Output("count: out_idx")
1545     .Attr("T: type")
1546     .Attr("out_idx: {int32, int64} = DT_INT32")
__anon38bbb0e82002(InferenceContext* c) 1547     .SetShapeFn([](InferenceContext* c) {
1548       auto uniq = c->Vector(InferenceContext::kUnknownDim);
1549       c->set_output(0, uniq);
1550       c->set_output(1, c->input(0));
1551       c->set_output(2, uniq);
1552       return OkStatus();
1553     });
1554 
1555 REGISTER_OP("UniqueWithCountsV2")
1556     .Input("x: T")
1557     .Input("axis: Taxis")
1558     .Output("y: T")
1559     .Output("idx: out_idx")
1560     .Output("count: out_idx")
1561     .Attr("T: type")
1562     .Attr("Taxis: {int32,int64} = DT_INT64")
1563     .Attr("out_idx: {int32, int64} = DT_INT32")
__anon38bbb0e82102(InferenceContext* c) 1564     .SetShapeFn([](InferenceContext* c) {
1565       c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
1566       TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
1567       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
1568       return OkStatus();
1569     });
1570 
1571 namespace {
1572 
ShapeShapeFn(InferenceContext * c)1573 Status ShapeShapeFn(InferenceContext* c) {
1574   for (int i = 0; i < c->num_inputs(); ++i) {
1575     DimensionHandle dim;
1576     if (c->RankKnown(c->input(i))) {
1577       dim = c->MakeDim(c->Rank(c->input(i)));
1578     } else {
1579       dim = c->UnknownDim();
1580     }
1581     c->set_output(i, c->Vector(dim));
1582   }
1583   return OkStatus();
1584 }
1585 
1586 }  // namespace
1587 
1588 // --------------------------------------------------------------------------
// Registers "Shape": one input of any type T and one output of type
// `out_type` (int32 or int64, int32 by default). Shape inference is handled
// by ShapeShapeFn.
REGISTER_OP("Shape")
    .Input("input: T")
    .Output("output: out_type")
    .Attr("T: type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn(ShapeShapeFn);
1595 
// Registers "ShapeN": N inputs of type T and N outputs of type `out_type`
// (int32 or int64, int32 by default). It shares ShapeShapeFn with "Shape",
// which emits one output shape per input.
REGISTER_OP("ShapeN")
    .Input("input: N * T")
    .Output("output: N * out_type")
    .Attr("N: int")
    .Attr("T: type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn(ShapeShapeFn);
1603 
1604 REGISTER_OP("EnsureShape")
1605     .Input("input: T")
1606     .Output("output: T")
1607     .Attr("shape: shape")
1608     .Attr("T: type")
__anon38bbb0e82302(InferenceContext* c) 1609     .SetShapeFn([](InferenceContext* c) {
1610       // Merges desired shape and statically known shape of input
1611       PartialTensorShape desired_shape;
1612       TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
1613 
1614       int rank = desired_shape.dims();
1615       ShapeHandle input_shape_handle;
1616       ShapeHandle desired_shape_handle;
1617       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle));
1618       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
1619           desired_shape, &desired_shape_handle));
1620 
1621       ShapeHandle merged_shape;
1622       TF_RETURN_IF_ERROR(
1623           c->Merge(desired_shape_handle, input_shape_handle, &merged_shape));
1624       c->set_output(0, merged_shape);
1625       return OkStatus();
1626     });
1627 
1628 // --------------------------------------------------------------------------
1629 REGISTER_OP("ReverseSequence")
1630     .Input("input: T")
1631     .Input("seq_lengths: Tlen")
1632     .Output("output: T")
1633     .Attr("seq_dim: int")
1634     .Attr("batch_dim: int = 0")
1635     .Attr("T: type")
1636     .Attr("Tlen: {int32, int64} = DT_INT64")
__anon38bbb0e82402(InferenceContext* c) 1637     .SetShapeFn([](InferenceContext* c) {
1638       ShapeHandle input = c->input(0);
1639       ShapeHandle seq_lens_shape;
1640       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_lens_shape));
1641 
1642       int64_t seq_dim;
1643       TF_RETURN_IF_ERROR(c->GetAttr("seq_dim", &seq_dim));
1644       int64_t batch_dim;
1645       TF_RETURN_IF_ERROR(c->GetAttr("batch_dim", &batch_dim));
1646 
1647       if (!c->RankKnown(input)) {
1648         return shape_inference::UnknownShape(c);
1649       }
1650 
1651       // Validate batch_dim and seq_dim against input.
1652       const int32_t input_rank = c->Rank(input);
1653       if (batch_dim >= input_rank) {
1654         return errors::InvalidArgument(
1655             "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
1656       }
1657 
1658       if (seq_dim >= input_rank) {
1659         return errors::InvalidArgument(
1660             "seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank);
1661       }
1662 
1663       // To prevent out of bound access when calling c->Dim(input, batch_dim),
1664       // batch_dim range [-1 * input rank, input rank) is allowed. However,
1665       // the op implementation has a stricter bound for batch_dim requiring >= 0
1666       // value. Thus, perform strict check here.
1667       if (batch_dim < 0) {
1668         return errors::InvalidArgument("batch_dim must be >=0, got ",
1669                                        batch_dim);
1670       }
1671 
1672       DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
1673       TF_RETURN_IF_ERROR(
1674           c->Merge(batch_dim_dim, c->Dim(seq_lens_shape, 0), &batch_dim_dim));
1675 
1676       // Replace batch_dim of input with batch_size
1677       ShapeHandle output_shape;
1678       TF_RETURN_IF_ERROR(
1679           c->ReplaceDim(input, batch_dim, batch_dim_dim, &output_shape));
1680       c->set_output(0, output_shape);
1681       return OkStatus();
1682     });
1683 
1684 // --------------------------------------------------------------------------
// Registers "Rank": one input of any type T and an int32 output. Shape
// inference is delegated to shape_inference::ScalarShape.
REGISTER_OP("Rank")
    .Input("input: T")
    .Output("output: int32")
    .Attr("T: type")
    .SetShapeFn(shape_inference::ScalarShape);
1690 
1691 // --------------------------------------------------------------------------
// Registers "Size": one input of any type T and one output of type `out_type`
// (int32 or int64, int32 by default). Shape inference is delegated to
// shape_inference::ScalarShape.
REGISTER_OP("Size")
    .Input("input: T")
    .Output("output: out_type")
    .Attr("T: type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ScalarShape);
1698 
1699 // --------------------------------------------------------------------------
// Registers "Slice": takes `input` plus `begin` and `size` index tensors of
// type `Index` (int32 or int64) and produces one output of type T. Shape
// inference is delegated to shape_inference::SliceShape.
REGISTER_OP("Slice")
    .Input("input: T")
    .Input("begin: Index")
    .Input("size: Index")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Index: {int32,int64}")
    .SetShapeFn(shape_inference::SliceShape);
1708 
// Only registered when INTEL_MKL is defined. "_MklSlice" has the same
// input/begin/size signature as "Slice", plus one uint8 `mkl_*` companion
// tensor per input and per output, and reuses shape_inference::SliceShape.
#ifdef INTEL_MKL
REGISTER_OP("_MklSlice")
    .Input("input: T")
    .Input("begin: Index")
    .Input("size: Index")
    .Input("mkl_input: uint8")
    .Input("mkl_begin: uint8")
    .Input("mkl_size: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: type")
    .Attr("Index: {int32,int64}")
    .SetShapeFn(shape_inference::SliceShape);
#endif
1723 
1724 REGISTER_OP("StridedSlice")
1725     .Input("input: T")
1726     .Input("begin: Index")
1727     .Input("end: Index")
1728     .Input("strides: Index")
1729     .Output("output: T")
1730     .Attr("T: type")
1731     .Attr("Index: {int16, int32, int64}")
1732     .Attr("begin_mask: int = 0")
1733     .Attr("end_mask: int = 0")
1734     .Attr("ellipsis_mask: int = 0")
1735     .Attr("new_axis_mask: int = 0")
1736     .Attr("shrink_axis_mask: int = 0")
__anon38bbb0e82502(InferenceContext* c) 1737     .SetShapeFn([](InferenceContext* c) {
1738       ShapeHandle input = c->input(0);
1739       ShapeHandle begin_shape, end_shape, strides_shape;
1740       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
1741       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &end_shape));
1742       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &strides_shape));
1743       TF_RETURN_IF_ERROR(c->Merge(begin_shape, end_shape, &begin_shape));
1744       TF_RETURN_IF_ERROR(c->Merge(begin_shape, strides_shape, &begin_shape));
1745       DimensionHandle sparse_dims_dim = c->Dim(begin_shape, 0);
1746 
1747       const Tensor* strides_value = c->input_tensor(3);
1748       // TODO(aselle,allenl): If we had a stride_mask it would be possible to do
1749       // more shape inference here (e.g. for x[3, ::T]).
1750       if (!c->RankKnown(input) || !c->ValueKnown(sparse_dims_dim) ||
1751           strides_value == nullptr) {
1752         c->set_output(0, c->UnknownShape());
1753         return OkStatus();
1754       }
1755 
1756       PartialTensorShape input_shape({});
1757       for (int i = 0; i < c->Rank(input); ++i) {
1758         auto dim = c->Dim(input, i);
1759         input_shape.AddDim(c->ValueKnown(dim) ? c->Value(dim) : -1);
1760       }
1761 
1762       int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask,
1763           shrink_axis_mask;
1764       TF_RETURN_IF_ERROR(c->GetAttr("begin_mask", &begin_mask));
1765       TF_RETURN_IF_ERROR(c->GetAttr("end_mask", &end_mask));
1766       TF_RETURN_IF_ERROR(c->GetAttr("ellipsis_mask", &ellipsis_mask));
1767       TF_RETURN_IF_ERROR(c->GetAttr("new_axis_mask", &new_axis_mask));
1768       TF_RETURN_IF_ERROR(c->GetAttr("shrink_axis_mask", &shrink_axis_mask));
1769 
1770       const Tensor* begin_value = c->input_tensor(1);
1771       const Tensor* end_value = c->input_tensor(2);
1772 
1773       PartialTensorShape processing_shape, final_shape;
1774       bool is_identity, is_simple_slice, slice_dim0;
1775       gtl::InlinedVector<int64, 4> begin, end, strides;
1776       TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
1777           begin_value, end_value, *strides_value, input_shape, begin_mask,
1778           end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
1779           &processing_shape, &final_shape, &is_identity, &is_simple_slice,
1780           &slice_dim0, &begin, &end, &strides));
1781 
1782       ShapeHandle out;
1783       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(final_shape, &out));
1784       c->set_output(0, out);
1785 
1786       auto* shape_and_type = c->input_handle_shapes_and_types(0);
1787       if (shape_and_type) {
1788         c->set_output_handle_shapes_and_types(0, *shape_and_type);
1789       }
1790 
1791       return OkStatus();
1792     });
1793 
1794 REGISTER_OP("StridedSliceGrad")
1795     .Input("shape: Index")
1796     .Input("begin: Index")
1797     .Input("end: Index")
1798     .Input("strides: Index")
1799     .Input("dy: T")
1800     .Output("output: T")
1801     .Attr("T: type")
1802     .Attr("Index: {int32, int64}")
1803     .Attr("begin_mask: int = 0")
1804     .Attr("end_mask: int = 0")
1805     .Attr("ellipsis_mask: int = 0")
1806     .Attr("new_axis_mask: int = 0")
1807     .Attr("shrink_axis_mask: int = 0")
__anon38bbb0e82602(InferenceContext* c) 1808     .SetShapeFn([](InferenceContext* c) {
1809       ShapeHandle out;
1810       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1811       c->set_output(0, out);
1812       return OkStatus();
1813     });
1814 
// Registers "StridedSliceAssign": takes a reference `ref`, the begin/end/
// strides slice specification (with the same mask attrs as StridedSlice) and
// a `value`, and returns a reference output. Shape inference is delegated to
// shape_inference::UnchangedShape.
REGISTER_OP("StridedSliceAssign")
    .Input("ref: Ref(T)")
    .Input("begin: Index")
    .Input("end: Index")
    .Input("strides: Index")
    .Input("value: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: type")
    .Attr("Index: {int32, int64}")
    .Attr("begin_mask: int = 0")
    .Attr("end_mask: int = 0")
    .Attr("ellipsis_mask: int = 0")
    .Attr("new_axis_mask: int = 0")
    .Attr("shrink_axis_mask: int = 0")
    .SetShapeFn(shape_inference::UnchangedShape);
// TODO(aselle): Fix this documentation once StridedSliceAssign supports
// broadcasting.
1832 // --------------------------------------------------------------------------
1833 
// Registers "ResourceStridedSliceAssign": the resource-handle counterpart of
// StridedSliceAssign. It takes a `resource` input instead of a Ref, declares
// no outputs, and uses shape_inference::NoOutputs.
REGISTER_OP("ResourceStridedSliceAssign")
    .Input("ref: resource")
    .Input("begin: Index")
    .Input("end: Index")
    .Input("strides: Index")
    .Input("value: T")
    .Attr("T: type")
    .Attr("Index: {int32, int64}")
    .Attr("begin_mask: int = 0")
    .Attr("end_mask: int = 0")
    .Attr("ellipsis_mask: int = 0")
    .Attr("new_axis_mask: int = 0")
    .Attr("shrink_axis_mask: int = 0")
    .SetShapeFn(shape_inference::NoOutputs);
1848 
// Registers "TensorStridedSliceUpdate": takes a plain tensor `input`, the
// begin/end/strides slice specification (with the same mask attrs as
// StridedSlice) and a `value`, and produces one output of type T. Shape
// inference is delegated to shape_inference::UnchangedShape.
REGISTER_OP("TensorStridedSliceUpdate")
    .Input("input: T")
    .Input("begin: Index")
    .Input("end: Index")
    .Input("strides: Index")
    .Input("value: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Index: {int32, int64}")
    .Attr("begin_mask: int = 0")
    .Attr("end_mask: int = 0")
    .Attr("ellipsis_mask: int = 0")
    .Attr("new_axis_mask: int = 0")
    .Attr("shrink_axis_mask: int = 0")
    .SetShapeFn(shape_inference::UnchangedShape);
1864 
1865 REGISTER_OP("Tile")
1866     .Input("input: T")
1867     .Input("multiples: Tmultiples")
1868     .Output("output: T")
1869     .Attr("T: type")
1870     .Attr("Tmultiples: {int32, int64} = DT_INT32")
__anon38bbb0e82702(InferenceContext* c) 1871     .SetShapeFn([](InferenceContext* c) {
1872       ShapeHandle input = c->input(0);
1873       // NOTE(mrry): Represent `multiples` as a `TensorShape` because (i)
1874       // it is a vector of non-negative integers, and (ii) doing so allows
1875       // us to handle partially-known multiples.
1876       ShapeHandle multiples;
1877       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &multiples));
1878       if (c->RankKnown(input)) {
1879         TF_RETURN_IF_ERROR(c->WithRank(multiples, c->Rank(input), &multiples));
1880         ShapeHandle dummy;
1881         TF_RETURN_IF_ERROR(
1882             c->Merge(c->input(1), c->Vector(c->Rank(input)), &dummy));
1883       }
1884 
1885       if (!c->RankKnown(multiples)) {
1886         return shape_inference::UnknownShape(c);
1887       }
1888 
1889       int32_t rank = c->Rank(multiples);
1890       TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
1891       std::vector<DimensionHandle> dims(rank);
1892       for (int i = 0; i < rank; ++i) {
1893         TF_RETURN_IF_ERROR(
1894             c->Multiply(c->Dim(input, i), c->Dim(multiples, i), &dims[i]));
1895       }
1896       c->set_output(0, c->MakeShape(dims));
1897       return OkStatus();
1898     });
1899 
1900 // --------------------------------------------------------------------------
// Registers "TileGrad", marked deprecated as of version 3 with the message
// below. Its output shape is always reported as unknown.
REGISTER_OP("TileGrad")
    .Input("input: T")
    .Input("multiples: int32")
    .Output("output: T")
    .Attr("T: type")
    .Deprecated(3, "TileGrad has been replaced with reduce_sum")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1908 
1909 // --------------------------------------------------------------------------
1910 REGISTER_OP("Where")
1911     .Input("input: T")
1912     .Attr("T: {numbertype, bool} = DT_BOOL")
1913     .Output("index: int64")
__anon38bbb0e82802(InferenceContext* c) 1914     .SetShapeFn([](InferenceContext* c) {
1915       c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0))));
1916       return OkStatus();
1917     });
1918 
1919 // --------------------------------------------------------------------------
1920 REGISTER_OP("BroadcastArgs")
1921     .Input("s0: T")
1922     .Input("s1: T")
1923     .Output("r0: T")
1924     .Attr("T: {int32, int64} = DT_INT32")
__anon38bbb0e82902(InferenceContext* c) 1925     .SetShapeFn([](InferenceContext* c) {
1926       ShapeHandle unused;
1927       ShapeHandle shape_x = c->input(0);
1928       ShapeHandle shape_y = c->input(1);
1929       TF_RETURN_IF_ERROR(c->WithRank(shape_x, 1, &unused));
1930       TF_RETURN_IF_ERROR(c->WithRank(shape_y, 1, &unused));
1931 
1932       if (!c->ValueKnown(c->Dim(shape_x, 0)) ||
1933           !c->ValueKnown(c->Dim(shape_y, 0))) {
1934         c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1935         return OkStatus();
1936       }
1937 
1938       int64_t x_dim = c->Value(c->Dim(shape_x, 0));
1939       int64_t y_dim = c->Value(c->Dim(shape_y, 0));
1940 
1941       // Broadcasted shape is going to be as large as the largest dimension.
1942       c->set_output(0, c->Vector(std::max(x_dim, y_dim)));
1943       return OkStatus();
1944     });
1945 
1946 // --------------------------------------------------------------------------
1947 REGISTER_OP("BroadcastGradientArgs")
1948     .Input("s0: T")
1949     .Input("s1: T")
1950     .Output("r0: T")
1951     .Output("r1: T")
1952     .Attr("T: {int32, int64} = DT_INT32")
__anon38bbb0e82a02(InferenceContext* c) 1953     .SetShapeFn([](InferenceContext* c) {
1954       // TODO(mrry): Implement constant_value for BroadcastGradientArgs?
1955       ShapeHandle unused;
1956       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1957       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
1958       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1959       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
1960       return OkStatus();
1961     });
1962 
1963 // --------------------------------------------------------------------------
// Registers "Pad": takes `input` and a `paddings` tensor of type `Tpaddings`
// (int32 or int64, int32 by default). Shape inference is handled by
// PadShapeFn, which is defined elsewhere in this file.
REGISTER_OP("Pad")
    .Input("input: T")
    .Input("paddings: Tpaddings")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    .SetShapeFn(PadShapeFn);
1971 
1972 // --------------------------------------------------------------------------
// Registers "PadV2": same signature as "Pad" plus a `constant_values` input
// of type T. It shares PadShapeFn with "Pad".
REGISTER_OP("PadV2")
    .Input("input: T")
    .Input("paddings: Tpaddings")
    .Input("constant_values: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    .SetShapeFn(PadShapeFn);
1981 
1982 // --------------------------------------------------------------------------
// Registers "MirrorPad": same input/paddings signature as "Pad", with an
// additional attr whose definition string comes from
// GetMirrorPadModeAttrString(). It shares PadShapeFn with "Pad".
REGISTER_OP("MirrorPad")
    .Input("input: T")
    .Input("paddings: Tpaddings")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    .Attr(GetMirrorPadModeAttrString())
    .SetShapeFn(PadShapeFn);
1991 
1992 // --------------------------------------------------------------------------
1993 namespace {
1994 template <typename T>
MirrorPadKnown(InferenceContext * c,ShapeHandle input,const Tensor * paddings_t,int64_t input_rank)1995 Status MirrorPadKnown(InferenceContext* c, ShapeHandle input,
1996                       const Tensor* paddings_t, int64_t input_rank) {
1997   auto paddings_data = paddings_t->matrix<T>();
1998   std::vector<DimensionHandle> dims(input_rank);
1999   for (int64_t i = 0; i < input_rank; ++i) {
2000     const int64_t pad0 = static_cast<int64_t>(paddings_data(i, 0));
2001     const int64_t pad1 = static_cast<int64_t>(paddings_data(i, 1));
2002     if (pad0 < 0 || pad1 < 0) {
2003       return errors::InvalidArgument("Paddings must be non-negative");
2004     }
2005 
2006     TF_RETURN_IF_ERROR(c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i]));
2007   }
2008   c->set_output(0, c->MakeShape(dims));
2009   return OkStatus();
2010 }
2011 
2012 }  // namespace
2013 
2014 REGISTER_OP("MirrorPadGrad")
2015     .Input("input: T")
2016     .Input("paddings: Tpaddings")
2017     .Output("output: T")
2018     .Attr("T: type")
2019     .Attr("Tpaddings: {int32, int64} = DT_INT32")
2020     .Attr(GetMirrorPadModeAttrString())
__anon38bbb0e82c02(InferenceContext* c) 2021     .SetShapeFn([](InferenceContext* c) {
2022       ShapeHandle paddings;
2023       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
2024       DimensionHandle pad_0 = c->Dim(paddings, 0);
2025       if (!c->ValueKnown(pad_0)) {
2026         // We don't know the rank of the output since the first
2027         // padding dimension is unknown.
2028         c->set_output(0, c->UnknownShape());
2029         return OkStatus();
2030       }
2031 
2032       int64_t input_rank = c->Value(pad_0);
2033       ShapeHandle input;
2034       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), input_rank, &input));
2035       TF_RETURN_IF_ERROR(
2036           c->Merge(paddings, c->Matrix(input_rank, 2), &paddings));
2037 
2038       const Tensor* paddings_t = c->input_tensor(1);
2039       if (paddings_t == nullptr) {
2040         // Values of 'paddings' is not available, but we know the
2041         // input rank, so return the rank of the output with unknown
2042         // dimensions.
2043         c->set_output(0, c->UnknownShapeOfRank(input_rank));
2044         return OkStatus();
2045       }
2046 
2047       if (paddings_t->dtype() == DT_INT32) {
2048         return MirrorPadKnown<int32>(c, input, paddings_t, input_rank);
2049       } else {
2050         return MirrorPadKnown<int64_t>(c, input, paddings_t, input_rank);
2051       }
2052     });
2053 
2054 // --------------------------------------------------------------------------
2055 REGISTER_OP("Placeholder")
2056     .Output("output: dtype")
2057     .Attr("dtype: type")
2058     .Attr("shape: shape = { unknown_rank: true }")
__anon38bbb0e82d02(InferenceContext* c) 2059     .SetShapeFn([](InferenceContext* c) {
2060       PartialTensorShape shape;
2061       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
2062 
2063       // Placeholder has legacy behavior where we cannot tell the difference
2064       // between a scalar shape attribute and 'unknown shape'.  So if the shape
2065       // is a scalar, we return an unknown shape.
2066       if (c->graph_def_version() <= 21 && shape.dims() <= 0) {
2067         return shape_inference::UnknownShape(c);
2068       }
2069 
2070       ShapeHandle out;
2071       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
2072       c->set_output(0, out);
2073       return OkStatus();
2074     });
2075 
// Placeholder was modified in a backwards compatible way to do what
// PlaceholderV2 did, so we have deprecated V2 (no one was really
// using it).
//
// The op is marked deprecated as of version 23. Unlike Placeholder, its
// `shape` attr has no default, and shape inference is delegated to
// shape_inference::ExplicitShape.
REGISTER_OP("PlaceholderV2")
    .Output("output: dtype")
    .Attr("dtype: type")
    .Attr("shape: shape")
    .SetShapeFn(shape_inference::ExplicitShape)
    .Deprecated(23, "Placeholder now behaves the same as PlaceholderV2.");
2085 
2086 // --------------------------------------------------------------------------
2087 REGISTER_OP("PlaceholderWithDefault")
2088     .Input("input: dtype")
2089     .Output("output: dtype")
2090     .Attr("dtype: type")
2091     .Attr("shape: shape")
__anon38bbb0e82e02(InferenceContext* c) 2092     .SetShapeFn([](InferenceContext* c) {
2093       ShapeHandle input = c->input(0);
2094       PartialTensorShape shape;
2095       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
2096       ShapeHandle out;
2097       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
2098 
2099       // We merge for compatibility checking, but return the output,
2100       // since output_shape may be less precise than input_shape.
2101       ShapeHandle unused;
2102       TF_RETURN_IF_ERROR(c->Merge(input, out, &unused));
2103       c->set_output(0, out);
2104       return OkStatus();
2105     });
2106 
2107 // --------------------------------------------------------------------------
2108 REGISTER_OP("ExpandDims")
2109     .Input("input: T")
2110     .Input("dim: Tdim")
2111     .Output("output: T")
2112     .Attr("T: type")
2113     .Attr("Tdim: {int32, int64} = DT_INT32")
__anon38bbb0e82f02(InferenceContext* c) 2114     .SetShapeFn([](InferenceContext* c) {
2115       ShapeHandle input = c->input(0);
2116 
2117       const Tensor* dim_t = c->input_tensor(1);
2118       if (dim_t != nullptr && dim_t->NumElements() != 1) {
2119         return errors::InvalidArgument(
2120             "'dim' input must be a tensor with a single value");
2121       }
2122       if (dim_t == nullptr || !c->RankKnown(input)) {
2123         c->set_output(0, c->UnknownShape());
2124         return OkStatus();
2125       }
2126 
2127       int64_t dim;
2128       if (dim_t->dtype() == DT_INT32) {
2129         dim = static_cast<int64_t>(dim_t->flat<int32>()(0));
2130       } else {
2131         dim = dim_t->flat<int64_t>()(0);
2132       }
2133 
2134       const int32_t rank = c->Rank(input);
2135       const int32_t min_dim = -1 * rank - 1;
2136       if (dim < min_dim || dim > rank) {
2137         return errors::InvalidArgument("dim ", dim, " not in the interval [",
2138                                        min_dim, ", ", rank, "].");
2139       }
2140 
2141       if (dim < 0) {
2142         dim += rank + 1;
2143       }
2144 
2145       ShapeHandle end;
2146       TF_RETURN_IF_ERROR(c->Subshape(input, dim, &end));
2147 
2148       // Build output as start + 1 + end.
2149       ShapeHandle output;
2150       TF_RETURN_IF_ERROR(c->Subshape(input, 0, dim, &output));
2151       TF_RETURN_IF_ERROR(c->Concatenate(output, c->Vector(1), &output));
2152       TF_RETURN_IF_ERROR(c->Concatenate(output, end, &output));
2153       c->set_output(0, output);
2154       return OkStatus();
2155     });
2156 
2157 // --------------------------------------------------------------------------
2158 REGISTER_OP("Squeeze")
2159     .Input("input: T")
2160     .Output("output: T")
2161     .Attr("T: type")
2162     .Attr("squeeze_dims: list(int) >= 0 = []")
__anon38bbb0e83002(InferenceContext* c) 2163     .SetShapeFn([](InferenceContext* c) {
2164       ShapeHandle input = c->input(0);
2165       if (!c->RankKnown(input)) {
2166         // Input shape unknown.
2167         return shape_inference::UnknownShape(c);
2168       }
2169 
2170       const int32_t input_rank = c->Rank(input);
2171 
2172       // Validate and wrap squeeze dimensions.
2173       std::vector<int32> squeeze_dims;
2174       TF_RETURN_IF_ERROR(c->GetAttr("squeeze_dims", &squeeze_dims));
2175       for (int i = 0; i < squeeze_dims.size(); ++i) {
2176         if (squeeze_dims[i] < -input_rank || squeeze_dims[i] >= input_rank) {
2177           return errors::InvalidArgument("squeeze_dims[", i, "] not in [",
2178                                          -input_rank, ",", input_rank, ").");
2179         }
2180 
2181         if (squeeze_dims[i] < 0) {
2182           squeeze_dims[i] += input_rank;
2183         }
2184       }
2185 
2186       std::vector<DimensionHandle> result_shape;
2187       for (int i = 0; i < input_rank; ++i) {
2188         // True if squeeze_dims contains an entry to squeeze this
2189         // dimension.
2190         bool is_explicit_match =
2191             std::find(squeeze_dims.begin(), squeeze_dims.end(), i) !=
2192             squeeze_dims.end();
2193 
2194         DimensionHandle dim = c->Dim(input, i);
2195 
2196         if (!c->ValueKnown(dim)) {
2197           // Assume that the squeezed dimension will be 1 at runtime.
2198           if (is_explicit_match) continue;
2199 
2200           // If squeezing all 1 dimensions, and we see an unknown value,
2201           // give up and return Unknown Shape.
2202           if (squeeze_dims.empty()) {
2203             c->set_output(0, c->UnknownShape());
2204             return OkStatus();
2205           }
2206         } else if (c->Value(dim) == 1) {
2207           if (is_explicit_match || squeeze_dims.empty()) {
2208             // If explicitly squeezing, or squeezing all 1s, remove
2209             // this dimension.
2210             continue;
2211           }
2212         } else if (is_explicit_match) {
2213           return errors::InvalidArgument("Can not squeeze dim[", i,
2214                                          "], expected a dimension of 1, got ",
2215                                          c->Value(c->Dim(input, i)));
2216         }
2217 
2218         result_shape.emplace_back(dim);
2219       }
2220 
2221       c->set_output(0, c->MakeShape(result_shape));
2222       return OkStatus();
2223     });
2224 
2225 // --------------------------------------------------------------------------
2226 REGISTER_OP("ListDiff")
2227     .Input("x: T")
2228     .Input("y: T")
2229     .Output("out: T")
2230     .Output("idx: out_idx")
2231     .Attr("T: type")
2232     .Attr("out_idx: {int32, int64} = DT_INT32")
__anon38bbb0e83102(InferenceContext* c) 2233     .SetShapeFn([](InferenceContext* c) {
2234       ShapeHandle unused;
2235       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
2236       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
2237       // TODO(mrry): Indicate that the length falls within an interval?
2238       ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
2239       c->set_output(0, out);
2240       c->set_output(1, out);
2241       return OkStatus();
2242     });
2243 
2244 namespace {
2245 
2246 // Converts Tensor to flat std::vector<int64_t>.
2247 template <typename InputType>
GetFlatInt64(const Tensor & t)2248 std::vector<int64_t> GetFlatInt64(const Tensor& t) {
2249   std::vector<int64_t> output(t.shape().num_elements());
2250   if (t.shape().num_elements() > 0) {
2251     auto eigen_vec = t.flat<InputType>();
2252     std::copy_n(&eigen_vec(0), output.size(), output.begin());
2253   }
2254   return output;
2255 }
2256 
2257 // Converts int32 or int64 Tensor to flat std::vector<int64_t>.
GetFlatInt64(const Tensor & t)2258 std::vector<int64_t> GetFlatInt64(const Tensor& t) {
2259   if (t.dtype() == DT_INT32) {
2260     return GetFlatInt64<int32>(t);
2261   } else {
2262     return GetFlatInt64<int64_t>(t);
2263   }
2264 }
2265 
SpaceToBatchShapeHelper(InferenceContext * c,ShapeHandle input_shape,ShapeHandle block_shape_shape,const Tensor * block_shape_t,ShapeHandle paddings_shape,const Tensor * paddings_t)2266 Status SpaceToBatchShapeHelper(InferenceContext* c, ShapeHandle input_shape,
2267                                ShapeHandle block_shape_shape,
2268                                const Tensor* block_shape_t,
2269                                ShapeHandle paddings_shape,
2270                                const Tensor* paddings_t) {
2271   if (c->Rank(block_shape_shape) != 1) {
2272     return errors::InvalidArgument("block_shape must have rank 1.");
2273   }
2274 
2275   const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
2276   if (!c->ValueKnown(num_block_dims_handle)) {
2277     return errors::InvalidArgument("block_shape must have known size.");
2278   }
2279 
2280   const int64_t num_block_dims = c->Value(num_block_dims_handle);
2281 
2282   TF_RETURN_IF_ERROR(
2283       c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
2284 
2285   TF_RETURN_IF_ERROR(
2286       c->Merge(paddings_shape, c->Matrix(num_block_dims, 2), &paddings_shape));
2287 
2288   DimensionHandle batch_size = c->Dim(input_shape, 0);
2289   std::vector<int64_t> block_shape_vec;
2290   if (block_shape_t && (block_shape_t->NumElements() > 0)) {
2291     block_shape_vec = GetFlatInt64(*block_shape_t);
2292     for (int64_t dim = 0; dim < num_block_dims; ++dim) {
2293       const int64_t block_shape_value = block_shape_vec[dim];
2294       if (block_shape_value < 1) {
2295         return errors::InvalidArgument("block_shape must be positive");
2296       }
2297       if (c->ValueKnown(batch_size)) {
2298         TF_RETURN_IF_ERROR(
2299             c->Multiply(batch_size, block_shape_value, &batch_size));
2300       } else {
2301         batch_size = c->UnknownDim();
2302       }
2303     }
2304   } else if (num_block_dims > 0) {
2305     batch_size = c->UnknownDim();
2306   }
2307 
2308   std::vector<DimensionHandle> output_dims{batch_size};
2309   output_dims.resize(num_block_dims + 1, c->UnknownDim());
2310 
2311   if (paddings_t && (paddings_t->NumElements() > 0)) {
2312     const std::vector<int64_t> paddings_vec = GetFlatInt64(*paddings_t);
2313     for (int64_t dim = 0; dim < num_block_dims; ++dim) {
2314       const int64_t pad_start = paddings_vec[dim * 2],
2315                     pad_end = paddings_vec[dim * 2 + 1];
2316       if (pad_start < 0 || pad_end < 0) {
2317         return errors::InvalidArgument("paddings cannot be negative");
2318       }
2319       if (block_shape_t) {
2320         DimensionHandle padded_size;
2321         TF_RETURN_IF_ERROR(
2322             c->Add(c->Dim(input_shape, dim + 1), pad_start, &padded_size));
2323         TF_RETURN_IF_ERROR(c->Add(padded_size, pad_end, &padded_size));
2324         TF_RETURN_IF_ERROR(c->Divide(padded_size, block_shape_vec[dim],
2325                                      /*evenly_divisible=*/true,
2326                                      &output_dims[dim + 1]));
2327       }
2328     }
2329   }
2330 
2331   ShapeHandle remaining_input_shape;
2332   TF_RETURN_IF_ERROR(
2333       c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
2334 
2335   ShapeHandle result;
2336   TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
2337                                     remaining_input_shape, &result));
2338   c->set_output(0, result);
2339   return OkStatus();
2340 }
2341 
BatchToSpaceShapeHelper(InferenceContext * c,ShapeHandle input_shape,ShapeHandle block_shape_shape,const Tensor * block_shape_t,ShapeHandle crops_shape,const Tensor * crops_t)2342 Status BatchToSpaceShapeHelper(InferenceContext* c, ShapeHandle input_shape,
2343                                ShapeHandle block_shape_shape,
2344                                const Tensor* block_shape_t,
2345                                ShapeHandle crops_shape, const Tensor* crops_t) {
2346   if (c->Rank(block_shape_shape) != 1) {
2347     return errors::InvalidArgument("block_shape must have rank 1.");
2348   }
2349 
2350   const DimensionHandle num_block_dims_handle = c->Dim(block_shape_shape, 0);
2351   if (!c->ValueKnown(num_block_dims_handle)) {
2352     return errors::InvalidArgument("block_shape must have known size.");
2353   }
2354 
2355   const int64_t num_block_dims = c->Value(num_block_dims_handle);
2356 
2357   TF_RETURN_IF_ERROR(
2358       c->WithRankAtLeast(input_shape, num_block_dims + 1, &input_shape));
2359 
2360   TF_RETURN_IF_ERROR(
2361       c->Merge(crops_shape, c->Matrix(num_block_dims, 2), &crops_shape));
2362 
2363   DimensionHandle batch_size = c->Dim(input_shape, 0);
2364   std::vector<int64_t> block_shape_vec;
2365   if (block_shape_t) {
2366     block_shape_vec = GetFlatInt64(*block_shape_t);
2367     for (int64_t dim = 0; dim < num_block_dims; ++dim) {
2368       const int64_t block_shape_value = block_shape_vec[dim];
2369       if (block_shape_value < 1) {
2370         return errors::InvalidArgument("block_shape must be positive");
2371       }
2372       if (c->ValueKnown(batch_size)) {
2373         TF_RETURN_IF_ERROR(c->Divide(batch_size, block_shape_value,
2374                                      /*evenly_divisible=*/true, &batch_size));
2375       } else {
2376         batch_size = c->UnknownDim();
2377       }
2378     }
2379   } else if (num_block_dims > 0) {
2380     batch_size = c->UnknownDim();
2381   }
2382 
2383   std::vector<DimensionHandle> output_dims{batch_size};
2384   output_dims.resize(num_block_dims + 1, c->UnknownDim());
2385 
2386   if (crops_t) {
2387     const std::vector<int64_t> crops_vec = GetFlatInt64(*crops_t);
2388     for (int64_t dim = 0; dim < num_block_dims; ++dim) {
2389       const int64_t crop_start = crops_vec[dim * 2],
2390                     crop_end = crops_vec[dim * 2 + 1];
2391       if (crop_start < 0 || crop_end < 0) {
2392         return errors::InvalidArgument("crops cannot be negative");
2393       }
2394       if (block_shape_t) {
2395         DimensionHandle cropped_size;
2396         TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, dim + 1),
2397                                        block_shape_vec[dim], &cropped_size));
2398         TF_RETURN_IF_ERROR(
2399             c->Subtract(cropped_size, crop_start, &cropped_size));
2400         TF_RETURN_IF_ERROR(
2401             c->Subtract(cropped_size, crop_end, &output_dims[dim + 1]));
2402       }
2403     }
2404   }
2405 
2406   ShapeHandle remaining_input_shape;
2407   TF_RETURN_IF_ERROR(
2408       c->Subshape(input_shape, 1 + num_block_dims, &remaining_input_shape));
2409 
2410   ShapeHandle result;
2411   TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape(output_dims),
2412                                     remaining_input_shape, &result));
2413   c->set_output(0, result);
2414   return OkStatus();
2415 }
2416 
2417 }  // namespace
2418 
2419 // --------------------------------------------------------------------------
2420 REGISTER_OP("SpaceToBatchND")
2421     .Input("input: T")
2422     .Input("block_shape: Tblock_shape")
2423     .Input("paddings: Tpaddings")
2424     .Output("output: T")
2425     .Attr("T: type")
2426     .Attr("Tblock_shape: {int32, int64} = DT_INT32")
2427     .Attr("Tpaddings: {int32, int64} = DT_INT32")
__anon38bbb0e83302(InferenceContext* c) 2428     .SetShapeFn([](InferenceContext* c) {
2429       return SpaceToBatchShapeHelper(c, c->input(0), c->input(1),
2430                                      c->input_tensor(1), c->input(2),
2431                                      c->input_tensor(2));
2432     });
2433 
2434 // --------------------------------------------------------------------------
2435 REGISTER_OP("SpaceToBatch")
2436     .Input("input: T")
2437     .Input("paddings: Tpaddings")
2438     .Output("output: T")
2439     .Attr("T: type")
2440     .Attr("Tpaddings: {int32, int64} = DT_INT32")
2441     .Attr("block_size: int >= 2")
__anon38bbb0e83402(InferenceContext* c) 2442     .SetShapeFn([](InferenceContext* c) {
2443       ShapeHandle input_shape;
2444       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2445 
2446       int32_t block_size;
2447       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2448 
2449       Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
2450       auto block_shape_vec = block_shape.vec<int64_t>();
2451       block_shape_vec(0) = block_size;
2452       block_shape_vec(1) = block_size;
2453 
2454       return SpaceToBatchShapeHelper(c, input_shape, c->MakeShape({2}),
2455                                      &block_shape, c->input(1),
2456                                      c->input_tensor(1));
2457     });
2458 
2459 // --------------------------------------------------------------------------
2460 REGISTER_OP("BatchToSpaceND")
2461     .Input("input: T")
2462     .Input("block_shape: Tblock_shape")
2463     .Input("crops: Tcrops")
2464     .Output("output: T")
2465     .Attr("T: type")
2466     .Attr("Tblock_shape: {int32, int64} = DT_INT32")
2467     .Attr("Tcrops: {int32, int64} = DT_INT32")
__anon38bbb0e83502(InferenceContext* c) 2468     .SetShapeFn([](InferenceContext* c) {
2469       return BatchToSpaceShapeHelper(c, c->input(0), c->input(1),
2470                                      c->input_tensor(1), c->input(2),
2471                                      c->input_tensor(2));
2472     });
2473 
2474 // --------------------------------------------------------------------------
2475 REGISTER_OP("BatchToSpace")
2476     .Input("input: T")
2477     .Input("crops: Tidx")
2478     .Output("output: T")
2479     .Attr("T: type")
2480     .Attr("block_size: int >= 2")
2481     .Attr("Tidx: {int32, int64} = DT_INT32")
__anon38bbb0e83602(InferenceContext* c) 2482     .SetShapeFn([](InferenceContext* c) {
2483       ShapeHandle input_shape;
2484       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2485 
2486       int32_t block_size;
2487       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2488 
2489       Tensor block_shape(tensorflow::DT_INT64, TensorShape({2}));
2490       auto block_shape_vec = block_shape.vec<int64_t>();
2491       block_shape_vec(0) = block_size;
2492       block_shape_vec(1) = block_size;
2493 
2494       return BatchToSpaceShapeHelper(c, input_shape, c->MakeShape({2}),
2495                                      &block_shape, c->input(1),
2496                                      c->input_tensor(1));
2497     });
2498 
2499 // --------------------------------------------------------------------------
2500 REGISTER_OP("SpaceToDepth")
2501     .Input("input: T")
2502     .Output("output: T")
2503     .Attr("T: type")
2504     .Attr("block_size: int >= 2")
2505     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
2506     // TODO(pauldonnelly): Implement GPU kernels for NCHW_VECT_C.
__anon38bbb0e83702(InferenceContext* c) 2507     .SetShapeFn([](InferenceContext* c) {
2508       string data_format_str;
2509       TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
2510       TensorFormat data_format;
2511       FormatFromString(data_format_str, &data_format);
2512 
2513       constexpr int num_spatial_dims = 2;
2514       const int dims =
2515           GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
2516       ShapeHandle input;
2517       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
2518 
2519       int32_t block_size;
2520       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2521 
2522       DimensionHandle batch_size =
2523           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
2524       DimensionHandle input_height =
2525           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
2526       DimensionHandle input_width =
2527           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
2528       DimensionHandle input_depth =
2529           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
2530 
2531       DimensionHandle output_height;
2532       DimensionHandle output_width;
2533       DimensionHandle output_depth;
2534       // Will return an error if input height or width are not evenly divisible.
2535       TF_RETURN_IF_ERROR(c->Divide(input_height, block_size,
2536                                    true /* evenly_divisible */,
2537                                    &output_height));
2538       TF_RETURN_IF_ERROR(c->Divide(input_width, block_size,
2539                                    true /* evenly_divisible */, &output_width));
2540 
2541       TF_RETURN_IF_ERROR(
2542           c->Multiply(input_depth, block_size * block_size, &output_depth));
2543 
2544       ShapeHandle output_shape;
2545       TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
2546                                              {output_height, output_width},
2547                                              output_depth, &output_shape, c));
2548 
2549       c->set_output(0, output_shape);
2550       return OkStatus();
2551     });
2552 
2553 // --------------------------------------------------------------------------
2554 REGISTER_OP("DepthToSpace")
2555     .Input("input: T")
2556     .Output("output: T")
2557     .Attr("T: type")
2558     .Attr("block_size: int >= 2")
2559     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
2560     // TODO(pauldonnelly): Implement GPU kernels for NCHW and NCHW_VECT_C.
__anon38bbb0e83802(InferenceContext* c) 2561     .SetShapeFn([](InferenceContext* c) {
2562       string data_format_str;
2563       TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
2564       TensorFormat data_format;
2565       FormatFromString(data_format_str, &data_format);
2566 
2567       constexpr int num_spatial_dims = 2;
2568       const int dims =
2569           GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
2570 
2571       ShapeHandle input;
2572       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
2573 
2574       int32_t block_size;
2575       TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2576 
2577       DimensionHandle batch_size =
2578           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
2579       DimensionHandle input_height =
2580           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
2581       DimensionHandle input_width =
2582           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
2583       DimensionHandle input_depth =
2584           c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
2585 
2586       DimensionHandle output_height;
2587       DimensionHandle output_width;
2588       DimensionHandle output_depth;
2589       TF_RETURN_IF_ERROR(c->Multiply(input_height, block_size, &output_height));
2590       TF_RETURN_IF_ERROR(c->Multiply(input_width, block_size, &output_width));
2591 
2592       // Will return an error if input_depth is not evenly divisible.
2593       TF_RETURN_IF_ERROR(c->Divide(input_depth, block_size * block_size,
2594                                    true /* evenly_divisible */, &output_depth));
2595 
2596       ShapeHandle output_shape;
2597       TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size,
2598                                              {output_height, output_width},
2599                                              output_depth, &output_shape, c));
2600 
2601       c->set_output(0, output_shape);
2602       return OkStatus();
2603     });
2604 
2605 // --------------------------------------------------------------------------
2606 
2607 REGISTER_OP("ExtractImagePatches")
2608     .Input("images: T")
2609     .Output("patches: T")
2610     .Attr("ksizes: list(int) >= 4")
2611     .Attr("strides: list(int) >= 4")
2612     .Attr("rates: list(int) >= 4")
2613     .Attr(
2614         "T: {bfloat16, half, float, double, int8, int16, int32, int64, "
2615         "uint8, uint16, uint32, uint64, complex64, complex128, bool}")
2616     .Attr(GetPaddingAttrString())
__anon38bbb0e83902(InferenceContext* c) 2617     .SetShapeFn([](InferenceContext* c) {
2618       ShapeHandle input_shape;
2619       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2620 
2621       std::vector<int32> ksizes;
2622       TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
2623       if (ksizes.size() != 4) {
2624         return errors::InvalidArgument(
2625             "ExtractImagePatches requires the ksizes attribute to contain 4 "
2626             "values, but got: ",
2627             ksizes.size());
2628       }
2629 
2630       std::vector<int32> strides;
2631       TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
2632       if (strides.size() != 4) {
2633         return errors::InvalidArgument(
2634             "ExtractImagePatches requires the stride attribute to contain 4 "
2635             "values, but got: ",
2636             strides.size());
2637       }
2638 
2639       std::vector<int32> rates;
2640       TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
2641       if (rates.size() != 4) {
2642         return errors::InvalidArgument(
2643             "ExtractImagePatches requires the rates attribute to contain 4 "
2644             "values, but got: ",
2645             rates.size());
2646       }
2647 
2648       int32_t ksize_rows = ksizes[1];
2649       int32_t ksize_cols = ksizes[2];
2650 
2651       int32_t stride_rows = strides[1];
2652       int32_t stride_cols = strides[2];
2653 
2654       int32_t rate_rows = rates[1];
2655       int32_t rate_cols = rates[2];
2656 
2657       int32_t ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
2658       int32_t ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
2659 
2660       DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
2661       DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
2662       DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
2663       DimensionHandle output_depth_dim;
2664       TF_RETURN_IF_ERROR(c->Multiply(
2665           c->Dim(input_shape, 3), ksize_rows * ksize_cols, &output_depth_dim));
2666 
2667       if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim)) {
2668         ShapeHandle output_shape =
2669             c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
2670                           InferenceContext::kUnknownDim, output_depth_dim});
2671         c->set_output(0, output_shape);
2672         return OkStatus();
2673       }
2674       auto in_rows = c->Value(in_rows_dim);
2675       auto in_cols = c->Value(in_cols_dim);
2676 
2677       Padding padding;
2678       TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
2679 
2680       int64_t output_rows, output_cols;
2681       int64_t padding_before, padding_after;
2682       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2683           in_rows, ksize_rows_eff, stride_rows, padding, &output_rows,
2684           &padding_before, &padding_after));
2685       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2686           in_cols, ksize_cols_eff, stride_cols, padding, &output_cols,
2687           &padding_before, &padding_after));
2688       ShapeHandle output_shape = c->MakeShape(
2689           {batch_size_dim, output_rows, output_cols, output_depth_dim});
2690       c->set_output(0, output_shape);
2691       return OkStatus();
2692     });
2693 
2694 // --------------------------------------------------------------------------
2695 
2696 // To enable rates, uncomment all lines commented below and use ksize_*_eff
2697 // as the second parameter of all GetWindowedOutputSizeVerbose calls instead
2698 // of ksize_*.
2699 REGISTER_OP("ExtractVolumePatches")
2700     .Input("input: T")
2701     .Output("patches: T")
2702     .Attr("ksizes: list(int) >= 5")
2703     .Attr("strides: list(int) >= 5")
2704     /* .Attr("rates: list(int) >= 5") */
2705     .Attr("T: realnumbertype")
2706     .Attr(GetPaddingAttrString())
__anon38bbb0e83a02(InferenceContext* c) 2707     .SetShapeFn([](InferenceContext* c) {
2708       ShapeHandle input_shape;
2709       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
2710 
2711       std::vector<int32> ksizes;
2712       TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
2713       if (ksizes.size() != 5) {
2714         return errors::InvalidArgument(
2715             "ExtractVolumePatches requires the ksizes attribute to contain 5 "
2716             "values, but got: ",
2717             ksizes.size());
2718       }
2719 
2720       std::vector<int32> strides;
2721       TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
2722       if (strides.size() != 5) {
2723         return errors::InvalidArgument(
2724             "ExtractVolumePatches requires the stride attribute to contain 5 "
2725             "values, but got: ",
2726             strides.size());
2727       }
2728 
2729       /*
2730       // TODO(hsgkim): Enable rates.
2731       // See extract_volume_patches_op.cc for why rates are disabled now.
2732 
2733       std::vector<int32> rates;
2734       TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
2735       if (rates.size() != 5) {
2736         return errors::InvalidArgument(
2737             "ExtractVolumePatches requires the rates attribute to contain 5 "
2738             "values, but got: ",
2739             rates.size());
2740       }
2741       */
2742 
2743       int32_t ksize_planes = ksizes[1];
2744       int32_t ksize_rows = ksizes[2];
2745       int32_t ksize_cols = ksizes[3];
2746 
2747       int32_t stride_planes = strides[1];
2748       int32_t stride_rows = strides[2];
2749       int32_t stride_cols = strides[3];
2750 
2751       /*
2752       int32 rate_planes = rates[1];
2753       int32 rate_rows = rates[2];
2754       int32 rate_cols = rates[3];
2755 
2756       int32 ksize_planes_eff = ksize_planes +
2757                                (ksize_planes - 1) * (rate_planes - 1);
2758       int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
2759       int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
2760       */
2761 
2762       DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
2763       DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
2764       DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
2765       DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
2766       DimensionHandle output_depth_dim;
2767       TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
2768                                      ksize_planes * ksize_rows * ksize_cols,
2769                                      &output_depth_dim));
2770 
2771       if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
2772           !c->ValueKnown(in_cols_dim)) {
2773         ShapeHandle output_shape =
2774             c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
2775                           InferenceContext::kUnknownDim, output_depth_dim});
2776         c->set_output(0, output_shape);
2777         return OkStatus();
2778       }
2779       auto in_planes = c->Value(in_planes_dim);
2780       auto in_rows = c->Value(in_rows_dim);
2781       auto in_cols = c->Value(in_cols_dim);
2782 
2783       Padding padding;
2784       TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
2785 
2786       int64_t output_planes, output_rows, output_cols;
2787       int64_t padding_before, padding_after;
2788       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2789           in_planes, ksize_planes, stride_planes, padding, &output_planes,
2790           &padding_before, &padding_after));
2791       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2792           in_rows, ksize_rows, stride_rows, padding, &output_rows,
2793           &padding_before, &padding_after));
2794       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2795           in_cols, ksize_cols, stride_cols, padding, &output_cols,
2796           &padding_before, &padding_after));
2797       ShapeHandle output_shape =
2798           c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
2799                         output_depth_dim});
2800       c->set_output(0, output_shape);
2801       return OkStatus();
2802     });
2803 
2804 // --------------------------------------------------------------------------
2805 
2806 REGISTER_OP("OneHot")
2807     .Input("indices: TI")
2808     .Input("depth: int32")
2809     .Input("on_value: T")
2810     .Input("off_value: T")
2811     .Attr("axis: int = -1")
2812     .Output("output: T")
2813     .Attr("T: type")
2814     .Attr("TI: {uint8, int32, int64} = DT_INT64")
__anon38bbb0e83b02(InferenceContext* c) 2815     .SetShapeFn([](InferenceContext* c) {
2816       int32_t axis;
2817       TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2818       if (axis < -1) return errors::InvalidArgument("axis must be >= -1");
2819 
2820       DimensionHandle depth;
2821       TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &depth));
2822 
2823       ShapeHandle indices = c->input(0);
2824       if (!c->RankKnown(indices)) return shape_inference::UnknownShape(c);
2825 
2826       int32_t new_rank = c->Rank(indices) + 1;
2827       // We need to add new_rank to axis in the case the axis is -1 because
2828       // C++ returns negative values from % if the dividend is negative.
2829       int32_t depth_index = (axis + new_rank) % new_rank;
2830       // Out shape is indices[0:depth_index] + [depth] + indices[depth_index:].
2831       ShapeHandle front;
2832       ShapeHandle back;
2833       ShapeHandle out;
2834       TF_RETURN_IF_ERROR(c->Subshape(indices, 0, depth_index, &front));
2835       TF_RETURN_IF_ERROR(c->Subshape(indices, depth_index, &back));
2836       TF_RETURN_IF_ERROR(c->Concatenate(front, c->Vector(depth), &front));
2837       TF_RETURN_IF_ERROR(c->Concatenate(front, back, &out));
2838       c->set_output(0, out);
2839       return OkStatus();
2840     });
2841 
2842 // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
2843 REGISTER_OP("QuantizeAndDequantize")
2844     .Input("input: T")
2845     .Attr("signed_input: bool = true")
2846     .Attr("num_bits: int = 8")
2847     .Attr("range_given: bool = false")
2848     .Attr("input_min: float = 0")
2849     .Attr("input_max: float = 0")
2850     .Output("output: T")
2851     .Attr("T: {bfloat16, half, float, double}")
2852     .SetShapeFn(shape_inference::UnchangedShape)
2853     .Deprecated(22, "Replaced by QuantizeAndDequantizeV2");
2854 
2855 // TODO(suharshs): Deprecate QuantizeAndDequantizeV2.
2856 REGISTER_OP("QuantizeAndDequantizeV2")
2857     .Input("input: T")
2858     .Input("input_min: T")
2859     .Input("input_max: T")
2860     .Attr("signed_input: bool = true")
2861     .Attr("num_bits: int = 8")
2862     .Attr("range_given: bool = false")
2863     .Output("output: T")
2864     .Attr("T: {bfloat16, half, float, double}")
2865     .Attr(
2866         "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = "
2867         "'HALF_TO_EVEN'")
2868     .Attr("narrow_range: bool = false")
2869     .Attr("axis: int = -1")
__anon38bbb0e83c02(InferenceContext* c) 2870     .SetShapeFn([](InferenceContext* c) {
2871       int axis;
2872       TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2873       const int minmax_rank = (axis == -1) ? 0 : 1;
2874       ShapeHandle minmax;
2875       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
2876       TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax));
2877       if (axis < -1) {
2878         return errors::InvalidArgument("axis should be at least -1, got ",
2879                                        axis);
2880       } else if (axis != -1) {
2881         ShapeHandle input;
2882         TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2883         DimensionHandle depth;
2884         TF_RETURN_IF_ERROR(
2885             c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2886       }
2887       c->set_output(0, c->input(0));
2888       return OkStatus();
2889     });
2890 
2891 REGISTER_OP("QuantizeAndDequantizeV4")
2892     .Input("input: T")
2893     .Input("input_min: T")
2894     .Input("input_max: T")
2895     .Attr("signed_input: bool = true")
2896     .Attr("num_bits: int = 8")
2897     .Attr("range_given: bool = false")
2898     .Output("output: T")
2899     .Attr("T: {bfloat16, half, float, double}")
2900     .Attr(
2901         "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = "
2902         "'HALF_TO_EVEN'")
2903     .Attr("narrow_range: bool = false")
2904     .Attr("axis: int = -1")
__anon38bbb0e83d02(InferenceContext* c) 2905     .SetShapeFn([](InferenceContext* c) {
2906       int axis;
2907       TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2908       const int minmax_rank = (axis == -1) ? 0 : 1;
2909       ShapeHandle minmax;
2910       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
2911       TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax));
2912       if (axis < -1) {
2913         return errors::InvalidArgument("axis should be at least -1, got ",
2914                                        axis);
2915       } else if (axis != -1) {
2916         ShapeHandle input;
2917         TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2918         DimensionHandle depth;
2919         TF_RETURN_IF_ERROR(
2920             c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2921       }
2922       c->set_output(0, c->input(0));
2923       return OkStatus();
2924     });
2925 
2926 REGISTER_OP("QuantizeAndDequantizeV4Grad")
2927     .Input("gradients: T")
2928     .Input("input: T")
2929     .Input("input_min: T")
2930     .Input("input_max: T")
2931     .Output("input_backprop: T")
2932     .Output("input_min_backprop: T")
2933     .Output("input_max_backprop: T")
2934     .Attr("T: {bfloat16, half, float, double}")
2935     .Attr("axis: int = -1")
__anon38bbb0e83e02(InferenceContext* c) 2936     .SetShapeFn([](InferenceContext* c) {
2937       int axis;
2938       TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2939       const int minmax_rank = (axis == -1) ? 0 : 1;
2940       ShapeHandle minmax;
2941       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
2942       TF_RETURN_IF_ERROR(c->Merge(c->input(3), minmax, &minmax));
2943       if (axis < -1) {
2944         return errors::InvalidArgument("axis should be at least -1, got ",
2945                                        axis);
2946       } else if (axis != -1) {
2947         ShapeHandle input;
2948         TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2949         DimensionHandle depth;
2950         TF_RETURN_IF_ERROR(
2951             c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2952       }
2953       ShapeHandle inputs;
2954       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
2955       c->set_output(0, inputs);
2956       c->set_output(1, minmax);
2957       c->set_output(2, minmax);
2958       return OkStatus();
2959     });
2960 
2961 REGISTER_OP("QuantizeAndDequantizeV3")
2962     .Input("input: T")
2963     .Input("input_min: T")
2964     .Input("input_max: T")
2965     .Input("num_bits: int32")
2966     .Attr("signed_input: bool = true")
2967     .Attr("range_given: bool = true")
2968     .Output("output: T")
2969     .Attr("T: {bfloat16, half, float, double}")
2970     .Attr("narrow_range: bool = false")
2971     .Attr("axis: int = -1")
__anon38bbb0e83f02(InferenceContext* c) 2972     .SetShapeFn([](InferenceContext* c) {
2973       int axis;
2974       TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
2975       const int minmax_rank = (axis == -1) ? 0 : 1;
2976       ShapeHandle minmax;
2977       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
2978       TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax));
2979       if (axis < -1) {
2980         return errors::InvalidArgument("axis should be at least -1, got ",
2981                                        axis);
2982       } else if (axis != -1) {
2983         ShapeHandle input;
2984         TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2985         DimensionHandle depth;
2986         TF_RETURN_IF_ERROR(
2987             c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2988       }
2989       ShapeHandle unused;
2990       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2991       c->set_output(0, c->input(0));
2992       return OkStatus();
2993     });
2994 
2995 REGISTER_OP("QuantizeV2")
2996     .Input("input: float")
2997     .Input("min_range: float")
2998     .Input("max_range: float")
2999     .Output("output: T")
3000     .Output("output_min: float")
3001     .Output("output_max: float")
3002     .Attr("T: quantizedtype")
3003     .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
3004     .Attr(
3005         "round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = "
3006         "'HALF_AWAY_FROM_ZERO'")
3007     .Attr("narrow_range: bool = false")
3008     .Attr("axis: int = -1")
3009     .Attr("ensure_minimum_range: float = 0.01")
3010     .SetShapeFn(shape_inference::QuantizeV2Shape);
3011 
3012 REGISTER_OP("Dequantize")
3013     .Input("input: T")
3014     .Input("min_range: float")
3015     .Input("max_range: float")
3016     .Output("output: dtype")
3017     .Attr("T: quantizedtype")
3018     .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
3019     .Attr("narrow_range: bool = false")
3020     .Attr("axis: int = -1")
3021     .Attr("dtype: {bfloat16, float} = DT_FLOAT")
__anon38bbb0e84002(InferenceContext* c) 3022     .SetShapeFn([](InferenceContext* c) {
3023       int axis = -1;
3024       Status s = c->GetAttr("axis", &axis);
3025       if (!s.ok() && s.code() != error::NOT_FOUND) {
3026         return s;
3027       }
3028       if (axis < -1) {
3029         return errors::InvalidArgument("axis should be at least -1, got ",
3030                                        axis);
3031       }
3032       auto input_dims = c->Rank(c->input(0));
3033       if (axis > input_dims) {
3034         return errors::InvalidArgument(
3035             "Axis must be less than input dimension(", input_dims, "), got ",
3036             axis);
3037       }
3038       const int minmax_rank = (axis == -1) ? 0 : 1;
3039       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
3040       ShapeHandle minmax;
3041       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
3042       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
3043       if (axis != -1) {
3044         ShapeHandle input;
3045         if (axis >= kint32max) {
3046           // Check int32 max bound for a corner case to prevent integer flow
3047           // when input actually has kint32max rank and above bound check is not
3048           // triggered.
3049           return errors::InvalidArgument(
3050               "Axis cannot be >= kint32max value, got ", axis);
3051         }
3052         TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
3053         DimensionHandle depth;
3054         TF_RETURN_IF_ERROR(
3055             c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
3056       }
3057       return OkStatus();
3058     });
3059 
3060 REGISTER_OP("QuantizedConcat")
3061     .Input("concat_dim: int32")
3062     .Input("values: N * T")
3063     .Input("input_mins: N * float32")
3064     .Input("input_maxes: N * float32")
3065     .Output("output: T")
3066     .Output("output_min: float")
3067     .Output("output_max: float")
3068     .Attr("N: int >= 2")
3069     .Attr("T: type")
__anon38bbb0e84102(InferenceContext* c) 3070     .SetShapeFn([](InferenceContext* c) {
3071       const int n = (c->num_inputs() - 1) / 3;
3072       TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c, n));
3073       ShapeHandle unused;
3074       for (int i = n + 1; i < c->num_inputs(); ++i) {
3075         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
3076       }
3077       c->set_output(1, c->Scalar());
3078       c->set_output(2, c->Scalar());
3079       return OkStatus();
3080     });
3081 
3082 REGISTER_OP("QuantizedReshape")
3083     .Input("tensor: T")
3084     .Input("shape: Tshape")
3085     .Input("input_min: float")
3086     .Input("input_max: float")
3087     .Output("output: T")
3088     .Output("output_min: float")
3089     .Output("output_max: float")
3090     .Attr("T: type")
3091     .Attr("Tshape: {int32, int64} = DT_INT32")
__anon38bbb0e84202(InferenceContext* c) 3092     .SetShapeFn([](InferenceContext* c) {
3093       TF_RETURN_IF_ERROR(SetOutputShapeForReshape(c));
3094       ShapeHandle unused;
3095       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
3096       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3097       c->set_output(1, c->Scalar());
3098       c->set_output(2, c->Scalar());
3099       return OkStatus();
3100     });
3101 
3102 REGISTER_OP("QuantizedInstanceNorm")
3103     .Input("x: T")
3104     .Input("x_min: float")
3105     .Input("x_max: float")
3106     .Output("y: T")
3107     .Output("y_min: float")
3108     .Output("y_max: float")
3109     .Attr("T: quantizedtype")
3110     .Attr("output_range_given: bool = false")
3111     .Attr("given_y_min: float = 0")
3112     .Attr("given_y_max: float = 0")
3113     .Attr("variance_epsilon: float = 1e-5")
3114     .Attr("min_separation: float = 1e-3")
__anon38bbb0e84302(shape_inference::InferenceContext* c) 3115     .SetShapeFn([](shape_inference::InferenceContext* c) {
3116       shape_inference::ShapeHandle unused;
3117       // x should be a rank 4 tensor.
3118       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &unused));
3119       // Assert x_min and x_max are scalars (rank 0).
3120       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
3121       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
3122       // y has the same shape as x.
3123       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
3124       // y_min and y_max are scalars.
3125       c->set_output(1, c->Scalar());
3126       c->set_output(2, c->Scalar());
3127       return OkStatus();
3128     });
3129 
3130 namespace {
3131 
ScatterNdTensorShape(InferenceContext * c)3132 Status ScatterNdTensorShape(InferenceContext* c) {
3133   ShapeHandle output_shape;
3134   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape));
3135   ShapeHandle indices_shape;
3136   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
3137   ShapeHandle updates_shape;
3138   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 0, &updates_shape));
3139   return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape,
3140                                                output_shape);
3141 }
3142 
3143 }  // namespace
3144 
3145 REGISTER_OP("UpperBound")
3146     .Input("sorted_inputs: T")
3147     .Input("values: T")
3148     .Output("output: out_type")
3149     .Attr("T: type")
3150     .Attr("out_type: {int32, int64} = DT_INT32")
__anon38bbb0e84502(InferenceContext* c) 3151     .SetShapeFn([](InferenceContext* c) {
3152       ShapeHandle unused_shape;
3153       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
3154       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
3155       c->set_output(0, c->input(1));
3156       return OkStatus();
3157     });
3158 
3159 REGISTER_OP("LowerBound")
3160     .Input("sorted_inputs: T")
3161     .Input("values: T")
3162     .Output("output: out_type")
3163     .Attr("T: type")
3164     .Attr("out_type: {int32, int64} = DT_INT32")
__anon38bbb0e84602(InferenceContext* c) 3165     .SetShapeFn([](InferenceContext* c) {
3166       ShapeHandle unused_shape;
3167       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
3168       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
3169       c->set_output(0, c->input(1));
3170       return OkStatus();
3171     });
3172 
3173 REGISTER_OP("ScatterNd")
3174     .Input("indices: Tindices")
3175     .Input("updates: T")
3176     .Input("shape: Tindices")
3177     .Output("output: T")
3178     .Attr("T: type")
3179     .Attr("Tindices: {int16, int32, int64}")
__anon38bbb0e84702(InferenceContext* c) 3180     .SetShapeFn([](InferenceContext* c) {
3181       ShapeHandle indices_shape;
3182       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &indices_shape));
3183       ShapeHandle updates_shape;
3184       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape));
3185       ShapeHandle output_shape;
3186       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output_shape));
3187       return shape_inference::ScatterNdShapeHelper(c, indices_shape,
3188                                                    updates_shape, output_shape);
3189     });
3190 
3191 REGISTER_OP("TensorScatterUpdate")
3192     .Input("tensor: T")
3193     .Input("indices: Tindices")
3194     .Input("updates: T")
3195     .Output("output: T")
3196     .Attr("T: type")
3197     .Attr("Tindices: {int16, int32, int64, uint16}")
3198     .SetShapeFn(ScatterNdTensorShape);
3199 
3200 REGISTER_OP("TensorScatterAdd")
3201     .Input("tensor: T")
3202     .Input("indices: Tindices")
3203     .Input("updates: T")
3204     .Output("output: T")
3205     .Attr("T: type")
3206     .Attr("Tindices: {int32, int64}")
3207     .SetShapeFn(ScatterNdTensorShape);
3208 
3209 REGISTER_OP("TensorScatterSub")
3210     .Input("tensor: T")
3211     .Input("indices: Tindices")
3212     .Input("updates: T")
3213     .Output("output: T")
3214     .Attr("T: type")
3215     .Attr("Tindices: {int32, int64}")
3216     .SetShapeFn(ScatterNdTensorShape);
3217 
3218 REGISTER_OP("TensorScatterMin")
3219     .Input("tensor: T")
3220     .Input("indices: Tindices")
3221     .Input("updates: T")
3222     .Output("output: T")
3223     .Attr("T: type")
3224     .Attr("Tindices: {int32, int64}")
3225     .SetShapeFn(ScatterNdTensorShape);
3226 
3227 REGISTER_OP("TensorScatterMax")
3228     .Input("tensor: T")
3229     .Input("indices: Tindices")
3230     .Input("updates: T")
3231     .Output("output: T")
3232     .Attr("T: type")
3233     .Attr("Tindices: {int32, int64}")
3234     .SetShapeFn(ScatterNdTensorShape);
3235 
3236 REGISTER_OP("ScatterNdNonAliasingAdd")
3237     .Input("input: T")
3238     .Input("indices: Tindices")
3239     .Input("updates: T")
3240     .Output("output: T")
3241     .Attr("T: {numbertype, bool}")
3242     .Attr("Tindices: {int32, int64}")
3243     .SetShapeFn(ScatterNdTensorShape);
3244 
3245 REGISTER_OP("FakeQuantWithMinMaxArgs")
3246     .Attr("min: float = -6.0")
3247     .Attr("max: float = 6.0")
3248     .Attr("num_bits: int = 8")
3249     .Attr("narrow_range: bool = false")
3250     .Input("inputs: float")
3251     .Output("outputs: float")
3252     .SetShapeFn(shape_inference::UnchangedShape);
3253 
3254 REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
3255     .Attr("min: float = -6.0")
3256     .Attr("max: float = 6.0")
3257     .Attr("num_bits: int = 8")
3258     .Attr("narrow_range: bool = false")
3259     .Input("gradients: float")
3260     .Input("inputs: float")
3261     .Output("backprops: float")
3262     .SetShapeFn(shape_inference::UnchangedShape);
3263 
3264 REGISTER_OP("FakeQuantWithMinMaxVars")
3265     .Attr("num_bits: int = 8")
3266     .Attr("narrow_range: bool = false")
3267     .Input("inputs: float")
3268     .Input("min: float")
3269     .Input("max: float")
3270     .Output("outputs: float")
__anon38bbb0e84802(InferenceContext* c) 3271     .SetShapeFn([](InferenceContext* c) {
3272       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
3273       ShapeHandle unused;
3274       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
3275       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
3276       return OkStatus();
3277     });
3278 
3279 REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
3280     .Attr("num_bits: int = 8")
3281     .Attr("narrow_range: bool = false")
3282     .Input("gradients: float")
3283     .Input("inputs: float")
3284     .Input("min: float")
3285     .Input("max: float")
3286     .Output("backprops_wrt_input: float")
3287     .Output("backprop_wrt_min: float")
3288     .Output("backprop_wrt_max: float")
__anon38bbb0e84902(InferenceContext* c) 3289     .SetShapeFn([](InferenceContext* c) {
3290       // gradients and inputs are same size.
3291       ShapeHandle inputs;
3292       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
3293 
3294       // min and max are scalars
3295       ShapeHandle min_max;
3296       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_max));
3297       TF_RETURN_IF_ERROR(c->Merge(min_max, c->input(3), &min_max));
3298 
3299       c->set_output(0, inputs);
3300       c->set_output(1, min_max);
3301       c->set_output(2, min_max);
3302       return OkStatus();
3303     });
3304 
3305 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel")
3306     .Attr("num_bits: int = 8")
3307     .Attr("narrow_range: bool = false")
3308     .Input("inputs: float")
3309     .Input("min: float")
3310     .Input("max: float")
3311     .Output("outputs: float")
__anon38bbb0e84a02(InferenceContext* c) 3312     .SetShapeFn([](InferenceContext* c) {
3313       ShapeHandle input, min, max;
3314       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
3315       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &min));
3316       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max));
3317 
3318       DimensionHandle unused;
3319       TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(min, 0), &unused));
3320       TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(max, 0), &unused));
3321       TF_RETURN_IF_ERROR(c->Merge(c->Dim(min, 0), c->Dim(max, 0), &unused));
3322 
3323       c->set_output(0, input);
3324       return OkStatus();
3325     });
3326 
3327 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
3328     .Attr("num_bits: int = 8")
3329     .Attr("narrow_range: bool = false")
3330     .Input("gradients: float")
3331     .Input("inputs: float")
3332     .Input("min: float")
3333     .Input("max: float")
3334     .Output("backprops_wrt_input: float")
3335     .Output("backprop_wrt_min: float")
3336     .Output("backprop_wrt_max: float")
__anon38bbb0e84b02(InferenceContext* c) 3337     .SetShapeFn([](InferenceContext* c) {
3338       ShapeHandle inputs;
3339       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &inputs));
3340       TF_RETURN_IF_ERROR(c->WithRankAtMost(inputs, 4, &inputs));
3341       TF_RETURN_IF_ERROR(c->Merge(inputs, c->input(1), &inputs));
3342 
3343       ShapeHandle last_dim = c->Vector(c->Dim(inputs, -1));
3344 
3345       ShapeHandle min_max;
3346       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &min_max));
3347       TF_RETURN_IF_ERROR(c->Merge(min_max, last_dim, &min_max));
3348       TF_RETURN_IF_ERROR(c->Merge(c->input(3), min_max, &min_max));
3349 
3350       c->set_output(0, inputs);
3351       c->set_output(1, min_max);
3352       c->set_output(2, min_max);
3353       return OkStatus();
3354     });
3355 
3356 REGISTER_OP("Fingerprint")
3357     .Input("data: T")
3358     .Input("method: string")
3359     .Output("fingerprint: uint8")
3360     .Attr("T: type")
__anon38bbb0e84c02(InferenceContext* c) 3361     .SetShapeFn([](InferenceContext* c) {
3362       ShapeHandle unused;
3363       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused));
3364       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
3365 
3366       DimensionHandle fingerprint_size;
3367       const Tensor* method = c->input_tensor(1);
3368       if (method == nullptr) {
3369         fingerprint_size = c->UnknownDim();
3370       } else {
3371         if (method->dims() != 0) {
3372           return errors::InvalidArgument("`method` must be rank 0: ",
3373                                          method->shape());
3374         }
3375         const string& method_string = method->scalar<tstring>()();
3376         if (method_string != "farmhash64") {
3377           return errors::InvalidArgument("Unsupported method: ", method_string);
3378         }
3379         fingerprint_size = c->MakeDim(sizeof(uint64));
3380       }
3381 
3382       DimensionHandle batch = c->Dim(c->input(0), 0);
3383       c->set_output(0, c->MakeShape({batch, fingerprint_size}));
3384       return OkStatus();
3385     });
3386 
#ifdef INTEL_MKL
// _MklConcat: MKL counterpart of Concat. On top of `concat_dim` and the N
// `values`, it carries one uint8 metadata input for `concat_dim` and one per
// value, so the op has 2N + 2 inputs in total.
REGISTER_OP("_MklConcat")
    .Input("concat_dim: int32")
    .Input("values: N * T")
    .Input("mkl_concat_dim: uint8")
    .Input("mkl_values: N * uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      // Recover N from the 2N + 2 input count so that only the `values`
      // inputs are handed to ConcatShape. Passing num_inputs() - 3 (2N - 1)
      // would also fold the uint8 metadata inputs into the concatenated
      // shape.
      const int num_values = (c->num_inputs() - 2) / 2;
      return shape_inference::ConcatShape(c, num_values);
    })
    .Doc(R"doc(
MKL version of Concat operator. Uses MKL DNN APIs to perform concatenation.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
#endif
3407 
3408 // Deprecated op registrations:
3409 
3410 // The following can be deleted after 10mar2017.
3411 REGISTER_OP("BatchMatrixDiag")
3412     .Input("diagonal: T")
3413     .Output("output: T")
3414     .Attr("T: type")
3415     .Deprecated(14, "Use MatrixDiag")
3416     .SetShapeFn(shape_inference::UnknownShape);
3417 REGISTER_OP("BatchMatrixSetDiag")
3418     .Input("input: T")
3419     .Input("diagonal: T")
3420     .Output("output: T")
3421     .Attr("T: type")
3422     .Deprecated(14, "Use MatrixSetDiag")
3423     .SetShapeFn(shape_inference::UnknownShape);
3424 REGISTER_OP("BatchMatrixDiagPart")
3425     .Input("input: T")
3426     .Output("diagonal: T")
3427     .Attr("T: type")
3428     .Deprecated(14, "Use MatrixDiagPart")
3429     .SetShapeFn(shape_inference::UnknownShape);
3430 REGISTER_OP("BatchMatrixBandPart")
3431     .Input("input: T")
3432     .Input("num_lower: int64")
3433     .Input("num_upper: int64")
3434     .Output("band: T")
3435     .Attr("T: type")
3436     .Deprecated(14, "Use MatrixBandPart")
3437     .SetShapeFn(shape_inference::UnknownShape);
3438 
3439 }  // namespace tensorflow
3440