1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <cmath>
18 
19 #include "tensorflow/core/framework/common_shape_fns.h"
20 #include "tensorflow/core/framework/kernel_shape_util.h"
21 #include "tensorflow/core/framework/numeric_op.h"
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/framework/shape_inference.h"
24 #include "tensorflow/core/lib/core/bits.h"
25 #include "tensorflow/core/lib/math/math_util.h"
26 #include "tensorflow/core/util/mirror_pad_mode.h"
27 #include "tensorflow/core/util/padding.h"
28 #include "tensorflow/core/util/tensor_format.h"
29 
30 namespace tensorflow {
31 
32 using shape_inference::DimensionHandle;
33 using shape_inference::InferenceContext;
34 using shape_inference::ShapeHandle;
35 
36 namespace {
37 
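// Shape function used by the fractional pooling ops (e.g. FractionalMaxPool
// below): each output dimension is floor(input_dim / pooling_ratio[i]), and
// outputs 1 and 2 are vectors sized by the pooled row and column dimensions
// (the pooling sequences).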
38 Status FractionalPoolShapeFn(InferenceContext* c) {
39   ShapeHandle input;
40   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
41 
42   std::vector<float> pooling_ratio;
43   TF_RETURN_IF_ERROR(c->GetAttr("pooling_ratio", &pooling_ratio));
44   if (pooling_ratio.size() != 4) {
45     return errors::InvalidArgument(
46         "pooling_ratio field must specify 4 dimensions");
47   }
48   std::vector<DimensionHandle> output_dims;
49   for (int i = 0; i < 4; ++i) {
50     DimensionHandle d = c->Dim(input, i);
51     if (c->ValueKnown(d)) {
52       // This must match the same logic in the kernel function in
53       // core/kernels/fractional_max_pool_op.cc.
54       auto val =
55           static_cast<int64_t>(std::floor(c->Value(d) / pooling_ratio[i]));
56       if (val < 0) {
57         return errors::InvalidArgument("Size computed for dim ", i,
58                                        " is negative: ", val);
59       }
60       output_dims.push_back(c->MakeDim(val));
61     } else {
62       output_dims.push_back(c->UnknownDim());
63     }
64   }
65 
66   c->set_output(0, c->MakeShape(output_dims));
67   c->set_output(1, c->Vector(output_dims[1]));
68   c->set_output(2, c->Vector(output_dims[2]));
69   return OkStatus();
70 }
71 
72 }  // namespace
73 
74 // --------------------------------------------------------------------------
75 
76 REGISTER_OP("AvgPool")
77     .Input("value: T")
78     .Output("output: T")
79     .Attr("ksize: list(int) >= 4")
80     .Attr("strides: list(int) >= 4")
81     .Attr(GetPaddingAttrString())
82     .Attr(GetConvnetDataFormatAttrString())
83     .Attr("T: {half, bfloat16, float, double}")
84     .SetShapeFn(shape_inference::AvgPoolShape);
85 
86 REGISTER_OP("AvgPoolGrad")
87     .Input("orig_input_shape: int32")
88     .Input("grad: T")
89     .Output("output: T")
90     .Attr("ksize: list(int) >= 4")
91     .Attr("strides: list(int) >= 4")
92     .Attr(GetPaddingAttrString())
93     .Attr(GetConvnetDataFormatAttrString())
94     .Attr("T: {half, bfloat16, float, double}")
95     .SetShapeFn(shape_inference::AvgPoolGradShape);
96 
97 // --------------------------------------------------------------------------
98 
99 REGISTER_OP("BatchNormWithGlobalNormalization")
100     .Input("t: T")
101     .Input("m: T")
102     .Input("v: T")
103     .Input("beta: T")
104     .Input("gamma: T")
105     .Output("result: T")
106     .Attr("T: numbertype")
107     .Attr("variance_epsilon: float")
108     .Attr("scale_after_normalization: bool")
109     .Deprecated(9, "Use tf.nn.batch_normalization()")
110     .SetShapeFn([](InferenceContext* c) {
111       ShapeHandle input;
112       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
113 
114       DimensionHandle last_dim = c->Dim(input, 3);
115       for (int i = 1; i < 5; ++i) {  // covers m, v, beta, gamma
116         ShapeHandle vec;
117         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
118         TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
119       }
120 
121       ShapeHandle out;
122       TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
123       c->set_output(0, out);
124       return OkStatus();
125     });
126 
127 REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
128     .Input("t: T")
129     .Input("m: T")
130     .Input("v: T")
131     .Input("gamma: T")
132     .Input("backprop: T")
133     .Output("dx: T")
134     .Output("dm: T")
135     .Output("dv: T")
136     .Output("db: T")
137     .Output("dg: T")
138     .Attr("T: numbertype")
139     .Attr("variance_epsilon: float")
140     .Attr("scale_after_normalization: bool")
141     .Deprecated(9, "Use tf.nn.batch_normalization()")
142     .SetShapeFn([](InferenceContext* c) {
143       ShapeHandle input;
144       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
145       TF_RETURN_IF_ERROR(
146           c->Merge(input, c->input(4), &input));  // with backprop
147 
148       DimensionHandle last_dim = c->Dim(input, 3);
149       for (int i = 1; i < 4; ++i) {  // covers m, v, gamma
150         ShapeHandle vec;
151         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
152         TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
153       }
154 
155       ShapeHandle dx;
156       TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &dx));
157       c->set_output(0, dx);
158 
159       ShapeHandle vector_shape = c->Vector(last_dim);
160       c->set_output(1, vector_shape);
161       c->set_output(2, vector_shape);
162       c->set_output(3, vector_shape);
163       c->set_output(4, vector_shape);
164       return OkStatus();
165     });
166 
167 // --------------------------------------------------------------------------
168 
169 REGISTER_OP("FusedBatchNorm")
170     .Input("x: T")
171     .Input("scale: T")
172     .Input("offset: T")
173     .Input("mean: T")
174     .Input("variance: T")
175     .Output("y: T")
176     .Output("batch_mean: T")
177     .Output("batch_variance: T")
178     .Output("reserve_space_1: T")
179     .Output("reserve_space_2: T")
180     .Attr("T: {float}")
181     .Attr("epsilon: float = 0.0001")
182     .Attr("exponential_avg_factor: float = 1.0")
183     .Attr(GetConvnetDataFormatAttrString())
184     .Attr("is_training: bool = true")
185     .SetShapeFn(shape_inference::FusedBatchNormShape);
186 
187 REGISTER_OP("FusedBatchNormV2")
188     .Input("x: T")
189     .Input("scale: U")
190     .Input("offset: U")
191     .Input("mean: U")
192     .Input("variance: U")
193     .Output("y: T")
194     .Output("batch_mean: U")
195     .Output("batch_variance: U")
196     .Output("reserve_space_1: U")
197     .Output("reserve_space_2: U")
198     .Attr("T: {half, bfloat16, float}")
199     .Attr("U: {float}")
200     .Attr("epsilon: float = 0.0001")
201     .Attr("exponential_avg_factor: float = 1.0")
202     .Attr(GetConvnetDataFormatAttrString())
203     .Attr("is_training: bool = true")
204     .SetShapeFn(shape_inference::FusedBatchNormShape);
205 
206 REGISTER_OP("FusedBatchNormV3")
207     .Input("x: T")
208     .Input("scale: U")
209     .Input("offset: U")
210     .Input("mean: U")
211     .Input("variance: U")
212     .Output("y: T")
213     .Output("batch_mean: U")
214     .Output("batch_variance: U")
215     .Output("reserve_space_1: U")
216     .Output("reserve_space_2: U")
217     .Output("reserve_space_3: U")
218     .Attr("T: {half, bfloat16, float}")
219     .Attr("U: {bfloat16, float}")
220     .Attr("epsilon: float = 0.0001")
221     .Attr("exponential_avg_factor: float = 1.0")
222     .Attr(GetConvnetDataFormat2D3DAttrString())
223     .Attr("is_training: bool = true")
224     .SetShapeFn(shape_inference::FusedBatchNormV3Shape);
225 
226 REGISTER_OP("_FusedBatchNormEx")
227     .Input("x: T")
228     .Input("scale: U")
229     .Input("offset: U")
230     .Input("mean: U")
231     .Input("variance: U")
232     .Input("side_input: num_side_inputs * T")
233     .Output("y: T")
234     .Output("batch_mean: U")
235     .Output("batch_variance: U")
236     .Output("reserve_space_1: U")
237     .Output("reserve_space_2: U")
238     .Output("reserve_space_3: U")
239     .Attr("T: {half, float, bfloat16}")
240     .Attr("U: {float}")
241     .Attr("epsilon: float = 0.0001")
242     .Attr("exponential_avg_factor: float = 1.0")
243     .Attr("num_side_inputs: int >= 0 = 0")
244     .Attr("activation_mode: string = \"Identity\"")
245     .Attr(GetConvnetDataFormatAttrString())
246     .Attr("is_training: bool = true")
247     .SetShapeFn(shape_inference::FusedBatchNormExShape)
248     .Doc(R"doc(
249 Internal FusedBatchNorm operation: reserved for internal use.
250 
251 Do not invoke this operator directly in Python. A fusion optimization is
252 expected to create these operators.
253 )doc");
254 
255 REGISTER_OP("FusedBatchNormGrad")
256     .Input("y_backprop: T")
257     .Input("x: T")
258     .Input("scale: T")
259     .Input("reserve_space_1: T")
260     .Input("reserve_space_2: T")
261     .Output("x_backprop: T")
262     .Output("scale_backprop: T")
263     .Output("offset_backprop: T")
264     .Output("reserve_space_3: T")
265     .Output("reserve_space_4: T")
266     .Attr("T: {float}")
267     .Attr("epsilon: float = 0.0001")
268     .Attr(GetConvnetDataFormatAttrString())
269     .Attr("is_training: bool = true")
270     .SetShapeFn(shape_inference::FusedBatchNormGradShape);
271 
272 REGISTER_OP("FusedBatchNormGradV2")
273     .Input("y_backprop: T")
274     .Input("x: T")
275     .Input("scale: float")
276     .Input("reserve_space_1: U")
277     .Input("reserve_space_2: U")
278     .Output("x_backprop: T")
279     .Output("scale_backprop: U")
280     .Output("offset_backprop: U")
281     .Output("reserve_space_3: U")
282     .Output("reserve_space_4: U")
283     .Attr("T: {half, bfloat16, float}")
284     .Attr("U: {float}")
285     .Attr("epsilon: float = 0.0001")
286     .Attr(GetConvnetDataFormatAttrString())
287     .Attr("is_training: bool = true")
288     .SetShapeFn(shape_inference::FusedBatchNormGradShape);
289 
290 REGISTER_OP("FusedBatchNormGradV3")
291     .Input("y_backprop: T")
292     .Input("x: T")
293     .Input("scale: float")
294     .Input("reserve_space_1: U")
295     .Input("reserve_space_2: U")
296     .Input("reserve_space_3: U")
297     .Output("x_backprop: T")
298     .Output("scale_backprop: U")
299     .Output("offset_backprop: U")
300     .Output("reserve_space_4: U")
301     .Output("reserve_space_5: U")
302     .Attr("T: {half, bfloat16, float}")
303     .Attr("U: {float}")
304     .Attr("epsilon: float = 0.0001")
305     .Attr(GetConvnetDataFormat2D3DAttrString())
306     .Attr("is_training: bool = true")
307     .SetShapeFn(shape_inference::FusedBatchNormGradShape);
308 
309 REGISTER_OP("_FusedBatchNormGradEx")
310     .Input("y_backprop: T")
311     .Input("x: T")
312     .Input("scale: float")
313     .Input("reserve_space_1: U")
314     .Input("reserve_space_2: U")
315     .Input("reserve_space_3: U")
316     .Input("offset: float")
317     .Input("y: T")
318     .Output("x_backprop: T")
319     .Output("scale_backprop: U")
320     .Output("offset_backprop: U")
321     .Output("reserve_space_4: U")
322     .Output("reserve_space_5: U")
323     .Output("side_input_backprop: num_side_inputs * T")
324     .Attr("T: {half, float}")
325     .Attr("U: {float}")
326     .Attr("epsilon: float = 0.0001")
327     .Attr("num_side_inputs: int >= 0 = 0")
328     .Attr("activation_mode: string = \"Identity\"")
329     .Attr(GetConvnetDataFormat2D3DAttrString())
330     .Attr("is_training: bool = true")
331     .SetShapeFn(shape_inference::FusedBatchNormGradExShape)
332     .Doc(R"doc(
333 Internal FusedBatchNormGrad operation: reserved for internal use.
334 
335 Do not invoke this operator directly in Python. A fusion optimization is
336 expected to create these operators.
337 )doc");
338 // --------------------------------------------------------------------------
339 
340 REGISTER_OP("BiasAdd")
341     .Attr("T: numbertype")
342     .Input("value: T")
343     .Input("bias: T")
344     .Attr(GetConvnetDataFormatAttrString())
345     .Output("output: T")
346     .SetShapeFn(shape_inference::BiasAddShape);
347 // --------------------------------------------------------------------------
348 
349 REGISTER_OP("BiasAddGrad")
350     .Attr("T: numbertype")
351     .Input("out_backprop: T")
352     .Attr(GetConvnetDataFormatAttrString())
353     .Output("output: T")
354     .SetShapeFn(shape_inference::BiasAddGradShape);
355 // --------------------------------------------------------------------------
356 
357 REGISTER_OP("BiasAddV1")
358     .Attr("T: numbertype")
359     .Input("value: T")
360     .Input("bias: T")
361     .Output("output: T")
362     .SetShapeFn(shape_inference::BiasAddShape);
363 // --------------------------------------------------------------------------
364 
365 REGISTER_OP("Conv2D")
366     .Input("input: T")
367     .Input("filter: T")
368     .Output("output: T")
369     .Attr("T: {half, bfloat16, float, double, int32}")
370     .Attr("strides: list(int)")
371     .Attr("use_cudnn_on_gpu: bool = true")
372     .Attr(GetPaddingAttrStringWithExplicit())
373     .Attr(GetExplicitPaddingsAttrString())
374     .Attr(GetConvnetDataFormatAttrString())
375     .Attr("dilations: list(int) = [1, 1, 1, 1]")
376     .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding);
377 
378 REGISTER_OP("Conv2DBackpropInput")
379     .Input("input_sizes: int32")
380     .Input("filter: T")
381     .Input("out_backprop: T")
382     .Output("output: T")
383     .Attr("T: {half, bfloat16, float, double, int32}")
384     .Attr("strides: list(int)")
385     .Attr("use_cudnn_on_gpu: bool = true")
386     .Attr(GetPaddingAttrStringWithExplicit())
387     .Attr(GetExplicitPaddingsAttrString())
388     .Attr(GetConvnetDataFormatAttrString())
389     .Attr("dilations: list(int) = [1, 1, 1, 1]")
390     .SetShapeFn(shape_inference::Conv2DBackpropInputShape);
391 
392 // TODO(jeff): Instead of 'use_cudnn_on_gpu', maybe we should have a
393 // more general string attribute ('kernel_impl'?) that can be used to
394 // select among several possible implementations.
395 REGISTER_OP("Conv2DBackpropFilter")
396     .Input("input: T")
397     .Input("filter_sizes: int32")
398     .Input("out_backprop: T")
399     .Output("output: T")
400     .Attr("T: {half, bfloat16, float, double}")
401     .Attr("strides: list(int)")
402     .Attr("use_cudnn_on_gpu: bool = true")
403     .Attr(GetPaddingAttrStringWithExplicit())
404     .Attr(GetExplicitPaddingsAttrString())
405     .Attr(GetConvnetDataFormatAttrString())
406     .Attr("dilations: list(int) = [1, 1, 1, 1]")
407     .SetShapeFn([](InferenceContext* c) {
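      // The filter gradient has the 4-D shape described by the 'filter_sizes'
      // input (input index 1), so materialize that shape tensor directly.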
408       ShapeHandle s;
409       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
410       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
411       c->set_output(0, s);
412       return OkStatus();
413     });
414 
415 REGISTER_OP("_FusedConv2D")
416     .Input("input: T")
417     .Input("filter: T")
418     .Input("args: num_args * T")
419     .Output("output: T")
420     .Attr("T: {half, float, double}")
421     .Attr("num_args: int >= 0")
422     .Attr("strides: list(int)")
423     .Attr(GetPaddingAttrStringWithExplicit())
424     .Attr(GetExplicitPaddingsAttrString())
425     .Attr(GetConvnetDataFormatAttrString())
426     .Attr("dilations: list(int) = [1, 1, 1, 1]")
427     .Attr("use_cudnn_on_gpu: bool = true")
428     .Attr("fused_ops: list(string) = []")
429     // Attributes for the FusedBatchNorm ------------------------------------ //
430     .Attr("epsilon: float = 0.0001")
431     // Attributes for the LeakyRelu ----------------------------------------- //
432     .Attr("leakyrelu_alpha: float = 0.2")
433     // ---------------------------------------------------------------------- //
434     .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
435     .Doc(R"doc(
436 Performs a convolution followed by a specified series of operations.
437 
438 The inputs to the convolution are `input` and `filter`. The series of operations
439 that follows is specified by the `fused_ops` attribute, which is a list of TF op
440 names specified as strings (e.g. "Relu"). They are performed in order, where the
441 (first) input to each op is the output of the preceding op. The first input and
442 the output of each fused_op must be of type T.
443 
444 Currently supported fused_op combinations are: [X] and [X,A], where X is one of
445 {"BiasAdd","FusedBatchNorm"} and A is one of {"Elu","Relu","Relu6"}.
446 
447 * The first input to op X is the Conv2D result, and the additional input(s) to X
448 are specified by `args`.
449 * If there is an op A specified, the output of op X is the input to op A, and op
450 A produces the _FusedConv2D output. Otherwise, op X produces the _FusedConv2D
451 output.
452 
453 *NOTE*: Do not invoke this operator directly in Python. Grappler is expected to
454 create these operators.
455 )doc");
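// For example, the fusion Conv2D -> BiasAdd -> Relu described above would be
// expressed as fused_ops = ["BiasAdd", "Relu"] with num_args = 1, where
// `args` carries the bias vector as the single additional input to BiasAdd.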
456 
457 namespace {
458 
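// Shape function shared by FusedResizeAndPadConv2D and FusedPadConv2D below.
// It optionally applies the resize (when 'has_resize' is true), then the
// explicit paddings, and finally the standard windowed Conv2D output-size
// computation on the padded shape.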
459 Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) {
460   ShapeHandle input;
461   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
462 
463   ShapeHandle resized = input;
464   int paddings_index = 1;
465   int filter_index = 2;
466   if (has_resize) {
467     paddings_index = 2;
468     filter_index = 3;
469 
470     ShapeHandle unused_size;
471     TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused_size));
472 
473     const Tensor* size = c->input_tensor(1);
474     DimensionHandle new_height = c->UnknownDim();
475     DimensionHandle new_width = c->UnknownDim();
476     if (size != nullptr) {
477       new_height = c->MakeDim(size->flat<int32>()(0));
478       new_width = c->MakeDim(size->flat<int32>()(1));
479     }
480     TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 1, new_height, &resized));
481     TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 2, new_width, &resized));
482   }
483 
484   ShapeHandle paddings;
485   TF_RETURN_IF_ERROR(c->WithRank(c->input(paddings_index), 2, &paddings));
486   TF_RETURN_IF_ERROR(
487       c->WithRank(resized, c->Value(c->Dim(paddings, 0)), &resized));
488   TF_RETURN_IF_ERROR(
489       c->Merge(paddings, c->Matrix(c->Rank(resized), 2), &paddings));
490 
491   const Tensor* paddings_t = c->input_tensor(paddings_index);
492   ShapeHandle padded;
493   if (paddings_t != nullptr) {
494     std::vector<DimensionHandle> output_dims;
495     for (int i = 0; i < 4; ++i) {
496       DimensionHandle dim = c->Dim(resized, i);
497       int64_t p0 = static_cast<int64_t>(paddings_t->matrix<int32>()(i, 0));
498       int64_t p1 = static_cast<int64_t>(paddings_t->matrix<int32>()(i, 1));
499       if (p0 < 0 || p1 < 0) {
500         return errors::InvalidArgument("Paddings must be non-negative");
501       }
502 
503       TF_RETURN_IF_ERROR(c->Add(dim, p0 + p1, &dim));
504       output_dims.push_back(dim);
505     }
506     padded = c->MakeShape(output_dims);
507   } else {
508     padded = c->UnknownShapeOfRank(4);
509   }
510 
511   // Work out the convolution's effect with 'padded' as the input.
512   ShapeHandle filter;
513   TF_RETURN_IF_ERROR(c->WithRank(c->input(filter_index), 4, &filter));
514   std::vector<int32> strides;
515   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
516   if (strides.size() != 4) {
517     return errors::InvalidArgument(
518         "Operation requires the stride attribute to contain 4 values, but ",
519         "got: ", strides.size());
520   }
521 
522   int32_t stride_rows = strides[1];
523   int32_t stride_cols = strides[2];
524 
525   DimensionHandle batch_size_dim = c->Dim(padded, 0);
526   DimensionHandle in_rows_dim = c->Dim(padded, 1);
527   DimensionHandle in_cols_dim = c->Dim(padded, 2);
528   DimensionHandle filter_rows_dim = c->Dim(filter, 0);
529   DimensionHandle filter_cols_dim = c->Dim(filter, 1);
530   DimensionHandle output_depth_dim = c->Dim(filter, 3);
531 
532   DimensionHandle unused;
533   TF_RETURN_IF_ERROR(c->Merge(c->Dim(padded, 3), c->Dim(filter, 2), &unused));
534 
535   Padding padding;
536   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
537 
538   DimensionHandle output_rows, output_cols;
539   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
540       c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
541   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
542       c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
543 
544   ShapeHandle output_shape = c->MakeShape(
545       {batch_size_dim, output_rows, output_cols, output_depth_dim});
546   c->set_output(0, output_shape);
547   return OkStatus();
548 }
549 
550 }  // namespace
551 
552 REGISTER_OP("DataFormatDimMap")
553     .Input("x: T")
554     .Output("y: T")
555     .Attr("T: {int32, int64} = DT_INT32")
556     .Attr("src_format: string = 'NHWC'")
557     .Attr("dst_format: string = 'NCHW'")
558     .SetShapeFn(shape_inference::UnchangedShape);
559 
560 REGISTER_OP("DataFormatVecPermute")
561     .Input("x: T")
562     .Output("y: T")
563     .Attr("T: {int32, int64} = DT_INT32")
564     .Attr("src_format: string = 'NHWC'")
565     .Attr("dst_format: string = 'NCHW'")
566     .SetShapeFn(shape_inference::UnchangedShape);
567 
568 REGISTER_OP("FusedResizeAndPadConv2D")
569     .Input("input: T")
570     .Input("size: int32")
571     .Input("paddings: int32")
572     .Input("filter: T")
573     .Output("output: T")
574     .Attr("T: {half, float, double}")
575     .Attr("resize_align_corners: bool = false")
576     .Attr(GetMirrorPadModeAttrString())
577     .Attr("strides: list(int)")
578     .Attr(GetPaddingAttrString())
579     .SetShapeFn([](InferenceContext* c) {
580       return CommonFusedConvCalculations(c, true /* has_resize */);
581     });
582 
583 REGISTER_OP("FusedPadConv2D")
584     .Input("input: T")
585     .Input("paddings: int32")
586     .Input("filter: T")
587     .Output("output: T")
588     .Attr("T: {half, float, double}")
589     .Attr(GetMirrorPadModeAttrString())
590     .Attr("strides: list(int)")
591     .Attr(GetPaddingAttrString())
592     .SetShapeFn([](InferenceContext* c) {
593       return CommonFusedConvCalculations(c, false /* has_resize */);
594     });
595 
596 // --------------------------------------------------------------------------
597 
598 REGISTER_OP("DepthwiseConv2dNative")
599     .Input("input: T")
600     .Input("filter: T")
601     .Output("output: T")
602     .Attr("T: {half, bfloat16, float, double}")
603     .Attr("strides: list(int)")
604     .Attr(GetPaddingAttrStringWithExplicit())
605     .Attr(GetExplicitPaddingsAttrString())
606     .Attr(GetConvnetDataFormatAttrString())
607     .Attr("dilations: list(int) = [1, 1, 1, 1]")
608     .SetShapeFn(shape_inference::DepthwiseConv2DNativeShapeWithExplicitPadding);
609 
610 REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
611     .Input("input_sizes: int32")
612     .Input("filter: T")
613     .Input("out_backprop: T")
614     .Output("output: T")
615     .Attr("T: {half, bfloat16, float, double}")
616     .Attr("strides: list(int)")
617     .Attr(GetPaddingAttrStringWithExplicit())
618     .Attr(GetExplicitPaddingsAttrString())
619     .Attr(GetConvnetDataFormatAttrString())
620     .Attr("dilations: list(int) = [1, 1, 1, 1]")
621     .SetShapeFn([](InferenceContext* c) {
622       ShapeHandle s;
623       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
624       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
625       c->set_output(0, s);
626       return OkStatus();
627     });
628 
629 REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
630     .Input("input: T")
631     .Input("filter_sizes: int32")
632     .Input("out_backprop: T")
633     .Output("output: T")
634     .Attr("T: {half, bfloat16, float, double}")
635     .Attr("strides: list(int)")
636     .Attr(GetPaddingAttrStringWithExplicit())
637     .Attr(GetExplicitPaddingsAttrString())
638     .Attr(GetConvnetDataFormatAttrString())
639     .Attr("dilations: list(int) = [1, 1, 1, 1]")
640     .SetShapeFn([](InferenceContext* c) {
641       ShapeHandle s;
642       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
643       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
644       c->set_output(0, s);
645       return OkStatus();
646     });
647 
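// Like _FusedConv2D above, this is an internal op: a DepthwiseConv2dNative
// followed by the ops named in 'fused_ops'. It is expected to be created by
// the fusion optimizer rather than invoked directly.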
648 REGISTER_OP("_FusedDepthwiseConv2dNative")
649     .Input("input: T")
650     .Input("filter: T")
651     .Input("args: num_args * T")
652     .Output("output: T")
653     .Attr("T: {half, bfloat16, float, double}")
654     .Attr("num_args: int >= 0")
655     .Attr("strides: list(int)")
656     .Attr(GetPaddingAttrString())
657     .Attr(GetConvnetDataFormatAttrString())
658     .Attr("dilations: list(int) = [1, 1, 1, 1]")
659     .Attr("fused_ops: list(string) = []")
660     // Attributes for the FusedBatchNorm ------------------------------------ //
661     .Attr("epsilon: float = 0.0001")
662     // Attributes for the LeakyRelu ----------------------------------------- //
663     .Attr("leakyrelu_alpha: float = 0.2")
664     // ---------------------------------------------------------------------- //
665     .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
666 
667 // --------------------------------------------------------------------------
668 
669 REGISTER_OP("Conv3D")
670     .Input("input: T")
671     .Input("filter: T")
672     .Output("output: T")
673     .Attr("T: {half, bfloat16, float, double}")
674     .Attr("strides: list(int) >= 5")
675     .Attr(GetPaddingAttrString())
676     .Attr(GetConvnet3dDataFormatAttrString())
677     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
678     .SetShapeFn(shape_inference::Conv3DShape);
679 
680 REGISTER_OP("Conv3DBackpropInput")
681     .Input("input: T")
682     .Input("filter: T")
683     .Input("out_backprop: T")
684     .Output("output: T")
685     .Attr("T: {half, float, double}")
686     .Attr("strides: list(int) >= 5")
687     .Attr(GetPaddingAttrString())
688     .Deprecated(10, "Use Conv3DBackpropInputV2")
689     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
690     .SetShapeFn([](InferenceContext* c) {
691       return UnchangedShapeWithRank(c, 5);
692     });
693 
694 REGISTER_OP("Conv3DBackpropFilter")
695     .Input("input: T")
696     .Input("filter: T")
697     .Input("out_backprop: T")
698     .Output("output: T")
699     .Attr("T: {half, float, double}")
700     .Attr("strides: list(int) >= 5")
701     .Attr(GetPaddingAttrString())
702     .Deprecated(10, "Use Conv3DBackpropFilterV2")
703     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
704     .SetShapeFn([](InferenceContext* c) {
705       ShapeHandle out;
706       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &out));
707       c->set_output(0, out);
708       return OkStatus();
709     });
710 
711 REGISTER_OP("Conv3DBackpropInputV2")
712     .Input("input_sizes: Tshape")
713     .Input("filter: T")
714     .Input("out_backprop: T")
715     .Output("output: T")
716     .Attr("T: {half, bfloat16, float, double}")
717     .Attr("strides: list(int) >= 5")
718     .Attr(GetPaddingAttrString())
719     .Attr(GetConvnet3dDataFormatAttrString())
720     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
721     .Attr("Tshape: {int32, int64} = DT_INT32")
722     .SetShapeFn([](InferenceContext* c) {
723       ShapeHandle s;
724       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
725       TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
726       c->set_output(0, s);
727       return OkStatus();
728     });
729 
730 REGISTER_OP("Conv3DBackpropFilterV2")
731     .Input("input: T")
732     .Input("filter_sizes: int32")
733     .Input("out_backprop: T")
734     .Output("output: T")
735     .Attr("T: {half, bfloat16, float, double}")
736     .Attr("strides: list(int) >= 5")
737     .Attr(GetPaddingAttrString())
738     .Attr(GetConvnet3dDataFormatAttrString())
739     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
740     .SetShapeFn([](InferenceContext* c) {
741       ShapeHandle s;
742       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
743       TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
744       c->set_output(0, s);
745       return OkStatus();
746     });
747 
748 // --------------------------------------------------------------------------
749 
750 REGISTER_OP("AvgPool3D")
751     .Input("input: T")
752     .Output("output: T")
753     .Attr("ksize: list(int) >= 5")
754     .Attr("strides: list(int) >= 5")
755     .Attr(GetPaddingAttrString())
756     .Attr(GetConvnet3dDataFormatAttrString())
757     .Attr("T: {half, bfloat16, float, double}")
758     .SetShapeFn(shape_inference::Pool3DShape);
759 
760 REGISTER_OP("AvgPool3DGrad")
761     .Input("orig_input_shape: int32")
762     .Input("grad: T")
763     .Output("output: T")
764     .Attr("ksize: list(int) >= 5")
765     .Attr("strides: list(int) >= 5")
766     .Attr(GetPaddingAttrString())
767     .Attr(GetConvnet3dDataFormatAttrString())
768     .Attr("T: {half, bfloat16, float, double}")
769     .SetShapeFn(shape_inference::AvgPool3DGradShape);
770 
771 // --------------------------------------------------------------------------
772 
773 REGISTER_OP("MaxPool3D")
774     .Input("input: T")
775     .Output("output: T")
776     .Attr("ksize: list(int) >= 5")
777     .Attr("strides: list(int) >= 5")
778     .Attr(GetPaddingAttrString())
779     .Attr(GetConvnet3dDataFormatAttrString())
780     .Attr("T: {half, bfloat16, float}")
781     .SetShapeFn(shape_inference::Pool3DShape);
782 
783 REGISTER_OP("MaxPool3DGrad")
784     .Input("orig_input: TInput")
785     .Input("orig_output: TInput")
786     .Input("grad: T")
787     .Output("output: T")
788     .Attr("ksize: list(int) >= 5")
789     .Attr("strides: list(int) >= 5")
790     .Attr(GetPaddingAttrString())
791     .Attr(GetConvnet3dDataFormatAttrString())
792     .Attr("T: {half, bfloat16, float} = DT_FLOAT")
793     .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
794     .SetShapeFn(shape_inference::MaxPool3DGradShape);
795 
796 REGISTER_OP("MaxPool3DGradGrad")
797     .Input("orig_input: T")
798     .Input("orig_output: T")
799     .Input("grad: T")
800     .Output("output: T")
801     .Attr("ksize: list(int) >= 5")
802     .Attr("strides: list(int) >= 5")
803     .Attr(GetPaddingAttrString())
804     .Attr(GetConvnet3dDataFormatAttrString())
805     .Attr("T: realnumbertype")
806     .SetShapeFn([](InferenceContext* c) {
807       TF_RETURN_IF_ERROR(shape_inference::Pool3DShape(c));
808       ShapeHandle unused;
809       // Validate 'orig_input' is the same shape as 'grad'
810       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
811       // Validate 'orig_output' is same shape as 'output'
812       TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
813       return OkStatus();
814     });
815 
816 // --------------------------------------------------------------------------
817 
818 REGISTER_OP("L2Loss")
819     .Input("t: T")
820     .Output("output: T")
821     .Attr("T: {half, bfloat16, float, double}")
822     .SetShapeFn(shape_inference::ScalarShape);
823 
824 // --------------------------------------------------------------------------
825 
826 REGISTER_OP("LRN")
827     .Input("input: T")
828     .Output("output: T")
829     .Attr("depth_radius: int = 5")
830     .Attr("bias: float = 1.0")
831     .Attr("alpha: float = 1.0")
832     .Attr("beta: float = 0.5")
833     .Attr("T: {half, bfloat16, float} = DT_FLOAT")
834     .SetShapeFn([](InferenceContext* c) {
835       return UnchangedShapeWithRank(c, 4);
836     });
837 
838 REGISTER_OP("LRNGrad")
839     .Input("input_grads: T")
840     .Input("input_image: T")
841     .Input("output_image: T")
842     .Output("output: T")
843     .Attr("depth_radius: int = 5")
844     .Attr("bias: float = 1.0")
845     .Attr("alpha: float = 1.0")
846     .Attr("beta: float = 0.5")
847     .Attr("T: {half, bfloat16, float} = DT_FLOAT")
848     .SetShapeFn([](InferenceContext* c) {
849       ShapeHandle s;
850       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s));  // input_grads
851       TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));     // input_image
852       TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s));     // output_image
853       c->set_output(0, s);
854       return OkStatus();
855     });
856 
857 // --------------------------------------------------------------------------
858 
859 REGISTER_OP("MaxPool")
860     .Attr(
861         "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
862         "uint16, qint8} = DT_FLOAT")
863     .Attr("ksize: list(int) >= 4")
864     .Attr("strides: list(int) >= 4")
865     .Attr(GetPaddingAttrStringWithExplicit())
866     .Attr(GetExplicitPaddingsAttrString())
867     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
868     .Input("input: T")
869     .Output("output: T")
870     .SetShapeFn(shape_inference::MaxPoolShapeWithExplicitPadding);
871 
872 REGISTER_OP("MaxPoolV2")
873     .Attr(
874         "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
875         "uint16, qint8} = DT_FLOAT")
876     .Attr(GetPaddingAttrString())
877     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
878     .Input("input: T")
879     .Input("ksize: int32")
880     .Input("strides: int32")
881     .Output("output: T")
882     .SetShapeFn([](InferenceContext* c) {
883       TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 3));
884       return OkStatus();
885     });
886 
887 REGISTER_OP("MaxPoolGrad")
888     .Attr("ksize: list(int) >= 4")
889     .Attr("strides: list(int) >= 4")
890     .Attr(GetPaddingAttrStringWithExplicit())
891     .Attr(GetExplicitPaddingsAttrString())
892     .Attr(GetConvnetDataFormatAttrString())
893     .Input("orig_input: T")
894     .Input("orig_output: T")
895     .Input("grad: T")
896     .Output("output: T")
897     .Attr("T: realnumbertype = DT_FLOAT")
898     .SetShapeFn(shape_inference::MaxPoolGradShape);
899 
900 REGISTER_OP("MaxPoolGradV2")
901     .Attr(GetPaddingAttrString())
902     .Attr(GetConvnetDataFormatAttrString())
903     .Input("orig_input: T")
904     .Input("orig_output: T")
905     .Input("grad: T")
906     .Input("ksize: int32")
907     .Input("strides: int32")
908     .Output("output: T")
909     .Attr("T: realnumbertype = DT_FLOAT")
910     .SetShapeFn(shape_inference::MaxPoolGradShape);
911 
912 // TODO(b/150813181): Implement explicit padding.
913 REGISTER_OP("MaxPoolGradGrad")
914     .Attr("ksize: list(int) >= 4")
915     .Attr("strides: list(int) >= 4")
916     .Attr(GetPaddingAttrString())
917     .Attr(GetConvnetDataFormatAttrString())
918     .Input("orig_input: T")
919     .Input("orig_output: T")
920     .Input("grad: T")
921     .Output("output: T")
922     .Attr("T: realnumbertype")
923     .SetShapeFn([](InferenceContext* c) {
924       TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
925       ShapeHandle unused;
926       // Validate 'orig_input' is the same shape as 'grad'
927       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
928       // Validate 'orig_output' is same shape as 'output'
929       TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
930       return OkStatus();
931     });
932 
933 REGISTER_OP("MaxPoolGradGradV2")
934     .Attr(GetPaddingAttrString())
935     .Attr(GetConvnetDataFormatAttrString())
936     .Input("orig_input: T")
937     .Input("orig_output: T")
938     .Input("grad: T")
939     .Input("ksize: int32")
940     .Input("strides: int32")
941     .Output("output: T")
942     .Attr("T: realnumbertype")
943     .SetShapeFn([](InferenceContext* c) {
944       TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 5));
945       ShapeHandle unused;
946       // Validate 'orig_input' is the same shape as 'grad'
947       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
948       // Validate 'orig_output' is same shape as 'output'
949       TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
950       return OkStatus();
951     });
952 
953 REGISTER_OP("MaxPoolWithArgmax")
954     .Attr("ksize: list(int) >= 4")
955     .Attr("strides: list(int) >= 4")
956     .Attr("Targmax: {int32, int64} = DT_INT64")
957     .Attr(GetPaddingAttrString())
958     .Attr("include_batch_in_index: bool = false")
959     .Input("input: T")
960     .Output("output: T")
961     .Output("argmax: Targmax")
962     .Attr("T: realnumbertype")
963     .SetShapeFn([](InferenceContext* c) {
964       TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
965       c->set_output(1, c->output(0));
966       return OkStatus();
967     });
968 
969 REGISTER_OP("MaxPoolGradWithArgmax")
970     .Attr("ksize: list(int) >= 4")
971     .Attr("strides: list(int) >= 4")
972     .Attr(GetPaddingAttrString())
973     .Attr("include_batch_in_index: bool = false")
974     .Attr("Targmax: {int32, int64}")
975     .Input("input: T")
976     .Input("grad: T")
977     .Input("argmax: Targmax")
978     .Output("output: T")
979     .Attr("T: realnumbertype")
980     .SetShapeFn([](InferenceContext* c) {
981       return UnchangedShapeWithRank(c, 4);
982     });
983 
984 REGISTER_OP("MaxPoolGradGradWithArgmax")
985     .Attr("ksize: list(int) >= 4")
986     .Attr("strides: list(int) >= 4")
987     .Attr(GetPaddingAttrString())
988     .Attr("include_batch_in_index: bool = false")
989     .Attr("Targmax: {int32, int64}")
990     .Input("input: T")
991     .Input("grad: T")
992     .Input("argmax: Targmax")
993     .Output("output: T")
994     .Attr("T: realnumbertype")
995     .SetShapeFn([](InferenceContext* c) {
996       TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
997       ShapeHandle unused;
998       // Validate 'orig_input' is the same shape as 'grad'
999       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &unused));
1000       // Validate 'argmax' is same shape as 'output'
1001       TF_RETURN_IF_ERROR(c->Merge(c->input(2), c->output(0), &unused));
1002       return OkStatus();
1003     });
1004 
1005 // --------------------------------------------------------------------------
1006 
1007 REGISTER_OP("Dilation2D")
1008     .Input("input: T")
1009     .Input("filter: T")
1010     .Output("output: T")
1011     .Attr("T: realnumbertype")
1012     .Attr("strides: list(int) >= 4")
1013     .Attr("rates: list(int) >= 4")
1014     .Attr(GetPaddingAttrString())
1015     .SetShapeFn([](InferenceContext* c) {
1016       ShapeHandle input_shape;
1017       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1018       ShapeHandle filter_shape;
1019       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &filter_shape));
1020 
1021       std::vector<int32> strides;
1022       TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1023       if (strides.size() != 4) {
1024         return errors::InvalidArgument(
1025             "Dilation2D requires the stride attribute to contain 4 values, but "
1026             "got: ",
1027             strides.size());
1028       }
1029 
1030       std::vector<int32> rates;
1031       TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
1032       if (rates.size() != 4) {
1033         return errors::InvalidArgument(
1034             "Dilation2D requires the rates attribute to contain 4 values, but "
1035             "got: ",
1036             rates.size());
1037       }
1038 
1039       int32_t stride_rows = strides[1];
1040       int32_t stride_cols = strides[2];
1041 
1042       int32_t rate_rows = rates[1];
1043       int32_t rate_cols = rates[2];
1044 
1045       DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
1046       DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
1047       DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
1048       DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
1049       DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
1050       DimensionHandle output_depth_dim = c->Dim(filter_shape, 2);
1051 
1052       if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim) ||
1053           !c->ValueKnown(filter_rows_dim) || !c->ValueKnown(filter_cols_dim)) {
1054         ShapeHandle output_shape =
1055             c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
1056                           InferenceContext::kUnknownDim, output_depth_dim});
1057         c->set_output(0, output_shape);
1058         return OkStatus();
1059       }
1060       DimensionHandle unused;
1061       TF_RETURN_IF_ERROR(
1062           c->Merge(c->Dim(input_shape, 3), output_depth_dim, &unused));
1063 
1064       auto in_rows = c->Value(in_rows_dim);
1065       auto in_cols = c->Value(in_cols_dim);
1066       auto filter_rows = c->Value(filter_rows_dim);
1067       auto filter_cols = c->Value(filter_cols_dim);
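      // Effective filter extent after dilation by 'rates':
      // size + (size - 1) * (rate - 1).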
1068       auto filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_rows - 1);
1069       auto filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_cols - 1);
1070 
1071       Padding padding;
1072       TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1073 
1074       int64_t output_rows, output_cols;
1075       int64_t padding_before, padding_after;
1076       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
1077           in_rows, filter_rows_eff, stride_rows, padding, &output_rows,
1078           &padding_before, &padding_after));
1079       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
1080           in_cols, filter_cols_eff, stride_cols, padding, &output_cols,
1081           &padding_before, &padding_after));
1082 
1083       ShapeHandle output_shape = c->MakeShape(
1084           {batch_size_dim, output_rows, output_cols, output_depth_dim});
1085       c->set_output(0, output_shape);
1086       return OkStatus();
1087     });
1088 
1089 REGISTER_OP("Dilation2DBackpropInput")
1090     .Input("input: T")
1091     .Input("filter: T")
1092     .Input("out_backprop: T")
1093     .Output("in_backprop: T")
1094     .Attr("T: realnumbertype")
1095     .Attr("strides: list(int) >= 4")
1096     .Attr("rates: list(int) >= 4")
1097     .Attr(GetPaddingAttrString())
1098     .SetShapeFn(shape_inference::UnchangedShape);
1099 
1100 REGISTER_OP("Dilation2DBackpropFilter")
1101     .Input("input: T")
1102     .Input("filter: T")
1103     .Input("out_backprop: T")
1104     .Output("filter_backprop: T")
1105     .Attr("T: realnumbertype")
1106     .Attr("strides: list(int) >= 4")
1107     .Attr("rates: list(int) >= 4")
1108     .Attr(GetPaddingAttrString())
1109     .SetShapeFn([](InferenceContext* c) {
1110       c->set_output(0, c->input(1));
1111       return OkStatus();
1112     });
1113 
1114 // --------------------------------------------------------------------------
1115 
1116 REGISTER_OP("Relu")
1117     .Input("features: T")
1118     .Output("activations: T")
1119     .Attr("T: {realnumbertype, qint8}")
1120     .SetShapeFn(shape_inference::UnchangedShape);
1121 
1122 REGISTER_OP("ReluGrad")
1123     .Input("gradients: T")
1124     .Input("features: T")
1125     .Output("backprops: T")
1126     .Attr("T: realnumbertype")
1127     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1128 
1129 REGISTER_OP("Relu6")
1130     .Input("features: T")
1131     .Output("activations: T")
1132     .Attr("T: realnumbertype")
1133     .SetShapeFn(shape_inference::UnchangedShape);
1134 
1135 REGISTER_OP("Relu6Grad")
1136     .Input("gradients: T")
1137     .Input("features: T")
1138     .Output("backprops: T")
1139     .Attr("T: realnumbertype")
1140     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1141 
1142 REGISTER_OP("LeakyRelu")
1143     .Input("features: T")
1144     .Output("activations: T")
1145     .Attr("alpha: float = 0.2")
1146     .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
1147     .SetShapeFn(shape_inference::UnchangedShape);
1148 
1149 REGISTER_OP("LeakyReluGrad")
1150     .Input("gradients: T")
1151     .Input("features: T")
1152     .Output("backprops: T")
1153     .Attr("alpha: float = 0.2")
1154     .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
1155     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1156 
1157 REGISTER_OP("Elu")
1158     .Input("features: T")
1159     .Output("activations: T")
1160     .Attr("T: {half, bfloat16, float, double}")
1161     .SetShapeFn(shape_inference::UnchangedShape);
1162 
1163 REGISTER_OP("EluGrad")
1164     .Input("gradients: T")
1165     .Input("outputs: T")
1166     .Output("backprops: T")
1167     .Attr("T: {half, bfloat16, float, double}")
1168     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1169 
1170 REGISTER_OP("Selu")
1171     .Input("features: T")
1172     .Output("activations: T")
1173     .Attr("T: {half, bfloat16, float, double}")
1174     .SetShapeFn(shape_inference::UnchangedShape);
1175 
1176 REGISTER_OP("SeluGrad")
1177     .Input("gradients: T")
1178     .Input("outputs: T")
1179     .Output("backprops: T")
1180     .Attr("T: {half, bfloat16, float, double}")
1181     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1182 
1183 REGISTER_OP("Softplus")
1184     .Input("features: T")
1185     .Output("activations: T")
1186     .Attr("T: {half, bfloat16, float, double}")
1187     .SetShapeFn(shape_inference::UnchangedShape);
1188 
1189 REGISTER_OP("SoftplusGrad")
1190     .Input("gradients: T")
1191     .Input("features: T")
1192     .Output("backprops: T")
1193     .Attr("T: {half, bfloat16, float, double}")
1194     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1195 
1196 REGISTER_OP("Softsign")
1197     .Input("features: T")
1198     .Output("activations: T")
1199     .Attr("T: {half, bfloat16, float, double}")
1200     .SetShapeFn(shape_inference::UnchangedShape);
1201 
1202 REGISTER_OP("SoftsignGrad")
1203     .Input("gradients: T")
1204     .Input("features: T")
1205     .Output("backprops: T")
1206     .Attr("T: {half, bfloat16, float, double}")
1207     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1208 
1209 // --------------------------------------------------------------------------
1210 
1211 REGISTER_OP("Softmax")
1212     .Input("logits: T")
1213     .Output("softmax: T")
1214     .Attr("T: {half, bfloat16, float, double}")
1215     .SetShapeFn([](InferenceContext* c) {
1216       return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
1217     });
1218 
1219 // --------------------------------------------------------------------------
1220 
1221 REGISTER_OP("LogSoftmax")
1222     .Input("logits: T")
1223     .Output("logsoftmax: T")
1224     .Attr("T: {half, bfloat16, float, double}")
1225     .SetShapeFn([](InferenceContext* c) {
1226       return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
1227     });
1228 
1229 // --------------------------------------------------------------------------
1230 
1231 REGISTER_OP("SoftmaxCrossEntropyWithLogits")
1232     .Input("features: T")
1233     .Input("labels: T")
1234     .Output("loss: T")
1235     .Output("backprop: T")
1236     .Attr("T: {half, bfloat16, float, double}")
1237     .SetShapeFn([](InferenceContext* c) {
1238       ShapeHandle input;
1239       if (c->WithRank(c->input(0), 2, &input) == OkStatus() &&
1240           c->Merge(input, c->input(1), &input) == OkStatus()) {
1241         DimensionHandle batch_size = c->Dim(input, 0);
1242         c->set_output(0, c->Vector(batch_size));
1243         c->set_output(1, input);
1244         return OkStatus();
1245       }
1246       TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFn(c, 1));
1247 
1248       if (!c->RankKnown(c->output(1))) {
1249         return errors::InvalidArgument(
1250             "Shape must be broadcasted with rank 2, but its rank is unknown.");
1251       }
1252 
1253       if (c->Rank(c->output(1)) != 2) {
1254         return errors::InvalidArgument(
1255             "Shape must be broadcasted with rank 2, but is rank ",
1256             c->Rank(c->output(1)));
1257       }
1258       DimensionHandle batch_size = c->Dim(c->output(1), 0);
1259       c->set_output(0, c->Vector(batch_size));
1260       return OkStatus();
1261     });
1262 
1263 REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits")
1264     .Input("features: T")
1265     .Input("labels: Tlabels")
1266     .Output("loss: T")
1267     .Output("backprop: T")
1268     .Attr("T: {half, bfloat16, float, double}")
1269     .Attr("Tlabels: {int32, int64} = DT_INT64")
1270     .SetShapeFn([](InferenceContext* c) {
1271       ShapeHandle features;
1272       ShapeHandle labels;
1273       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &features));
1274       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &labels));
1275 
1276       DimensionHandle batch_size;
1277       TF_RETURN_IF_ERROR(
1278           c->Merge(c->Dim(features, 0), c->Dim(labels, 0), &batch_size));
1279       TF_RETURN_IF_ERROR(c->ReplaceDim(features, 0, batch_size, &features));
1280 
1281       c->set_output(0, c->Vector(batch_size));
1282       c->set_output(1, features);
1283       return OkStatus();
1284     });
1285 
1286 // --------------------------------------------------------------------------
1287 
1288 REGISTER_OP("InTopK")
1289     .Input("predictions: float")
1290     .Input("targets: T")
1291     .Output("precision: bool")
1292     .Attr("k: int")
1293     .Attr("T: {int32, int64} = DT_INT32")
1294     .SetShapeFn([](InferenceContext* c) {
1295       ShapeHandle predictions;
1296       ShapeHandle targets;
1297       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
1298       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
1299       DimensionHandle batch_size;
1300       TF_RETURN_IF_ERROR(
1301           c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
1302       c->set_output(0, c->Vector(batch_size));
1303       return OkStatus();
1304     });
1305 
1306 // This is the same as `InTopK`, but takes `k` as an input rather than an attr.
1307 REGISTER_OP("InTopKV2")
1308     .Input("predictions: float")
1309     .Input("targets: T")
1310     .Input("k: T")
1311     .Output("precision: bool")
1312     .Attr("T: {int32, int64} = DT_INT32")
1313     .SetShapeFn([](InferenceContext* c) {
1314       ShapeHandle predictions;
1315       ShapeHandle targets;
1316       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
1317       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
1318       DimensionHandle batch_size;
1319       TF_RETURN_IF_ERROR(
1320           c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
1321       c->set_output(0, c->Vector(batch_size));
1322       return OkStatus();
1323     });
1324 
1325 namespace {
1326 
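// Shape function shared by TopK and TopKV2: k comes from the second input
// when present (TopKV2) or from the "k" attr (TopK), and both outputs have
// the input shape with the last dimension replaced by k.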
1327 Status TopKShapeFn(InferenceContext* c) {
1328   ShapeHandle input;
1329   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
1330 
1331   // Get the k value, either from input tensor or attribute.
1332   DimensionHandle k_dim;
1333   if (c->num_inputs() >= 2) {
1334     TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &k_dim));
1335   } else {
1336     int32_t k;
1337     TF_RETURN_IF_ERROR(c->GetAttr("k", &k));
1338     if (k < 0) {
1339       return errors::InvalidArgument("Need k >= 0, got ", k);
1340     }
1341     k_dim = c->MakeDim(k);
1342   }
1343 
1344   DimensionHandle last_dim = c->Dim(input, -1);
1345   if (c->ValueKnown(last_dim) && c->ValueKnown(k_dim) &&
1346       c->Value(last_dim) < c->Value(k_dim)) {
1347     return errors::InvalidArgument(
1348         "input must have last dimension >= k = ", c->Value(k_dim), " but is ",
1349         c->Value(last_dim));
1350   }
1351 
1352   // Replace last_dim with k_dim.
1353   ShapeHandle s;
1354   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
1355   TF_RETURN_IF_ERROR(c->Concatenate(s, c->Vector(k_dim), &s));
1356   c->set_output(0, s);
1357   c->set_output(1, s);
1358   return OkStatus();
1359 }
1360 
1361 // Utility functions for ApproxTopKShape.
1362 // It is not easy to link xla/client/lib into the tensorflow core lib, so we
1363 // have to replicate the logic.
1364 // LINT.IfChange
1365 inline uint32_t log2_floor(uint64_t value) {
1366   return value == 0 ? 0 : Log2Floor(value);
1367 }
1368 
1369 inline uint32_t log2_ceil(uint64_t value) {
1370   return value == 0 ? 0 : Log2Ceiling(value);
1371 }
1372 
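// Shape function for ApproxTopK. When 'aggregate_to_topk' is true the reduced
// dimension simply becomes k; otherwise it is rounded to a multiple of the
// TPU tile size (128, or 1024 for rank-1 inputs) derived from 'recall_target',
// mirroring the XLA logic referenced by the LINT block below.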
1373 Status ApproxTopKShape(shape_inference::InferenceContext* c) {
1374   int64_t k;
1375   int64_t reduction_dimension;
1376   float recall_target;
1377   int64_t reduction_input_size_override;
1378   bool aggregate_to_topk;
1379   TF_RETURN_IF_ERROR(c->GetAttr("k", &k));
1380   TF_RETURN_IF_ERROR(c->GetAttr("reduction_dimension", &reduction_dimension));
1381   TF_RETURN_IF_ERROR(c->GetAttr("recall_target", &recall_target));
1382   TF_RETURN_IF_ERROR(c->GetAttr("reduction_input_size_override",
1383                                 &reduction_input_size_override));
1384   TF_RETURN_IF_ERROR(c->GetAttr("aggregate_to_topk", &aggregate_to_topk));
1385   ShapeHandle input_shape;
1386   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
1387   if (reduction_dimension < 0) {
1388     // Reverse index
1389     reduction_dimension += c->Rank(input_shape);
1390   }
1391   int64_t reduction_dim_value =
1392       c->Value(c->Dim(input_shape, reduction_dimension));
1393 
1394   if (reduction_dim_value < k) {
1395     return errors::InvalidArgument("input must have last dimension >= k = ", k,
1396                                    " but was ", reduction_dim_value);
1397   }
1398 
1399   int64_t output_dim_value = [&] {
1400     if (aggregate_to_topk) {
1401       return k;
1402     }
1403     int64_t tpu_tiling = c->Rank(input_shape) == 1 ? 1024 : 128;
1404     if (reduction_dim_value <= tpu_tiling || recall_target == 1.0) {
1405       return reduction_dim_value;
1406     }
1407     if (k == 1) {
1408       return tpu_tiling;
1409     }
1410     uint64_t logical_input_size = reduction_input_size_override >= 0
1411                                       ? reduction_input_size_override
1412                                       : reduction_dim_value;
1413     uint64_t m = std::min<uint64_t>(
1414         std::max<uint64_t>(
1415             static_cast<uint64_t>((1.0 - k) /
1416                                   std::log(static_cast<double>(recall_target))),
1417             tpu_tiling),
1418         reduction_dim_value);
1419     uint32_t log2_reduction = log2_floor(logical_input_size / m);
1420     if (log2_reduction == 0) {
1421       return reduction_dim_value;
1422     }
1423     log2_reduction = std::min<uint32_t>(
1424         log2_reduction, log2_ceil(reduction_dim_value / tpu_tiling));
1425     return tensorflow::MathUtil::CeilOfRatio<int64_t>(
1426                tensorflow::MathUtil::CeilOfRatio<int64_t>(reduction_dim_value,
1427                                                           tpu_tiling),
1428                (1 << log2_reduction)) *
1429            tpu_tiling;
1430   }();
1431 
1432   auto output_dim = c->MakeDim(output_dim_value);
1433 
1434   ShapeHandle output_shape;
1435   TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, reduction_dimension, output_dim,
1436                                    &output_shape));
1437   c->set_output(0, output_shape);
1438   c->set_output(1, output_shape);
1439   return OkStatus();
1440 }
1441 // LINT.ThenChange(//tensorflow/compiler/xla/client/lib/approx_topk_shape.cc)
1442 
1443 }  // namespace
1444 
1445 REGISTER_OP("TopK")
1446     .Input("input: T")
1447     .Output("values: T")
1448     .Output("indices: int32")
1449     .Attr("k: int >= 0")
1450     .Attr("sorted: bool = true")
1451     .Attr("T: realnumbertype")
1452     .Deprecated(7, "Use TopKV2 instead")
1453     .SetShapeFn(TopKShapeFn);
1454 
1455 // This is the same as `TopK`, but takes `k` as an input rather than an attr.
1456 REGISTER_OP("TopKV2")
1457     .Input("input: T")
1458     .Input("k: int32")
1459     .Output("values: T")
1460     .Output("indices: int32")
1461     .Attr("sorted: bool = true")
1462     .Attr("T: realnumbertype")
1463     .SetShapeFn(TopKShapeFn);
1464 
1465 REGISTER_OP("ApproxTopK")
1466     .Input("input: T")
1467     .Output("values: T")
1468     .Output("indices: int32")
1469     .Attr("k: int >= 0")
1470     .Attr("reduction_dimension: int = -1")
1471     .Attr("recall_target: float = 0.95")
1472     .Attr("is_max_k: bool = true")
1473     .Attr("reduction_input_size_override: int = -1")
1474     .Attr("aggregate_to_topk: bool = true")
1475     .Attr("T: {half, bfloat16, float}")
1476     .SetShapeFn(ApproxTopKShape);
1477 
1478 // --------------------------------------------------------------------------
1479 
1480 REGISTER_OP("NthElement")
1481     .Input("input: T")
1482     .Input("n: int32")
1483     .Output("values: T")
1484     .Attr("reverse: bool = false")
1485     .Attr("T: realnumbertype")
1486     .SetShapeFn([](InferenceContext* c) {
1487       ShapeHandle input;
1488       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
1489 
1490       // Get the n value from the input tensor and make sure it is a scalar.
1491       DimensionHandle n_dim;
1492       TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &n_dim));
1493 
1494       // The last dimension of the input tensor must be greater than n.
1495       DimensionHandle last_dim = c->Dim(input, -1);
1496       if (c->ValueKnown(last_dim) && c->ValueKnown(n_dim) &&
1497           c->Value(last_dim) <= c->Value(n_dim)) {
1498         return errors::InvalidArgument(
1499             "Input must have last dimension > n = ", c->Value(n_dim),
1500             " but is ", c->Value(last_dim));
1501       }
1502 
1503       // Reduce last_dim for output tensor
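      // e.g. an input of shape [batch, d] yields values of shape [batch].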
1504       ShapeHandle s;
1505       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
1506       c->set_output(0, s);
1507       return OkStatus();
1508     });
1509 
1510 // --------------------------------------------------------------------------
1511 
1512 REGISTER_OP("FractionalMaxPool")
1513     .Input("value: T")
1514     .Output("output: T")
1515     .Output("row_pooling_sequence: int64")
1516     .Output("col_pooling_sequence: int64")
1517     .Attr("pooling_ratio: list(float) >= 4")
1518     .Attr("pseudo_random: bool = false")
1519     .Attr("overlapping: bool = false")
1520     .Attr("deterministic: bool = false")
1521     .Attr("seed: int = 0")
1522     .Attr("seed2: int = 0")
1523     .Attr("T: {float, double, int32, int64}")
1524     .SetShapeFn(FractionalPoolShapeFn);
1525 
1526 REGISTER_OP("FractionalMaxPoolGrad")
1527     .Input("orig_input: T")
1528     .Input("orig_output: T")
1529     .Input("out_backprop: T")
1530     .Input("row_pooling_sequence: int64")
1531     .Input("col_pooling_sequence: int64")
1532     .Output("output: T")
1533     .Attr("overlapping: bool = false")
1534     .Attr("T: {float, double, int32, int64}")
1535     .SetShapeFn([](InferenceContext* c) {
1536       return shape_inference::UnchangedShapeWithRank(c, 4);
1537     });
1538 
1539 // --------------------------------------------------------------------------
1540 
1541 REGISTER_OP("FractionalAvgPool")
1542     .Input("value: T")
1543     .Output("output: T")
1544     .Output("row_pooling_sequence: int64")
1545     .Output("col_pooling_sequence: int64")
1546     .Attr("pooling_ratio: list(float) >= 4")
1547     .Attr("pseudo_random: bool = false")
1548     .Attr("overlapping: bool = false")
1549     .Attr("deterministic: bool = false")
1550     .Attr("seed: int = 0")
1551     .Attr("seed2: int = 0")
1552     .Attr("T: {float, double, int32, int64}")
1553     .SetShapeFn(FractionalPoolShapeFn);
1554 
1555 REGISTER_OP("FractionalAvgPoolGrad")
1556     .Input("orig_input_tensor_shape: int64")
1557     .Input("out_backprop: T")
1558     .Input("row_pooling_sequence: int64")
1559     .Input("col_pooling_sequence: int64")
1560     .Output("output: T")
1561     .Attr("overlapping: bool = false")
1562     .Attr("T: {float, double, int32, int64}")
1563     .SetShapeFn([](InferenceContext* c) {
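      // The output (input gradient) shape is read from the
      // orig_input_tensor_shape input when that tensor is known at graph
      // construction time; otherwise only rank 4 is known.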
1564       if (c->input_tensor(0) != nullptr) {
1565         ShapeHandle out;
1566         TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1567         c->set_output(0, out);
1568       } else {
1569         c->set_output(0, c->UnknownShapeOfRank(4));
1570       }
1571       return OkStatus();
1572     });
1573 
1574 REGISTER_OP("QuantizedAvgPool")
1575     .Input("input: T")
1576     .Input("min_input: float")
1577     .Input("max_input: float")
1578     .Output("output: T")
1579     .Output("min_output: float")
1580     .Output("max_output: float")
1581     .Attr("T: quantizedtype")
1582     .Attr("ksize: list(int)")
1583     .Attr("strides: list(int)")
1584     .Attr(GetPaddingAttrString())
1585     .SetShapeFn(shape_inference::QuantizedAvgPoolShape);
1586 
1587 REGISTER_OP("QuantizedBiasAdd")
1588     .Input("input: T1")
1589     .Input("bias: T2")
1590     .Input("min_input: float")
1591     .Input("max_input: float")
1592     .Input("min_bias: float")
1593     .Input("max_bias: float")
1594     .Output("output: out_type")
1595     .Output("min_out: float")
1596     .Output("max_out: float")
1597     .Attr("T1: quantizedtype")
1598     .Attr("T2: quantizedtype")
1599     .Attr("out_type: quantizedtype")
1600     .SetShapeFn([](InferenceContext* c) {
1601       TF_RETURN_IF_ERROR(shape_inference::BiasAddShape(c));
1602       ShapeHandle unused;
1603       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1604       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1605       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1606       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1607       c->set_output(1, c->Scalar());
1608       c->set_output(2, c->Scalar());
1609       return OkStatus();
1610     });
1611 
1612 REGISTER_OP("QuantizedConv2D")
1613     .Input("input: Tinput")
1614     .Input("filter: Tfilter")
1615     .Input("min_input: float")
1616     .Input("max_input: float")
1617     .Input("min_filter: float")
1618     .Input("max_filter: float")
1619     .Output("output: out_type")
1620     .Output("min_output: float")
1621     .Output("max_output: float")
1622     .Attr("Tinput: quantizedtype")
1623     .Attr("Tfilter: quantizedtype")
1624     .Attr("out_type: quantizedtype = DT_QINT32")
1625     .Attr("strides: list(int)")
1626     .Attr(GetPaddingAttrString())
1627     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1628     .SetShapeFn(shape_inference::QuantizedConv2DShape);
1629 
1630 REGISTER_OP("QuantizedMaxPool")
1631     .Input("input: T")
1632     .Input("min_input: float")
1633     .Input("max_input: float")
1634     .Output("output: T")
1635     .Output("min_output: float")
1636     .Output("max_output: float")
1637     .Attr("T: quantizedtype")
1638     .Attr("ksize: list(int)")
1639     .Attr("strides: list(int)")
1640     .Attr(GetPaddingAttrString())
1641     .SetShapeFn([](InferenceContext* c) {
1642       TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
1643       ShapeHandle unused;
1644       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1645       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1646       c->set_output(1, c->Scalar());
1647       c->set_output(2, c->Scalar());
1648       return OkStatus();
1649     });
1650 
1651 REGISTER_OP("QuantizedRelu")
1652     .Input("features: Tinput")
1653     .Input("min_features: float")
1654     .Input("max_features: float")
1655     .Output("activations: out_type")
1656     .Output("min_activations: float")
1657     .Output("max_activations: float")
1658     .Attr("Tinput: quantizedtype")
1659     .Attr("out_type: quantizedtype = DT_QUINT8")
1660     .SetShapeFn([](InferenceContext* c) {
1661       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1662       ShapeHandle unused;
1663       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1664       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1665       c->set_output(1, c->Scalar());
1666       c->set_output(2, c->Scalar());
1667       return OkStatus();
1668     });
1669 
1670 REGISTER_OP("QuantizedRelu6")
1671     .Input("features: Tinput")
1672     .Input("min_features: float")
1673     .Input("max_features: float")
1674     .Output("activations: out_type")
1675     .Output("min_activations: float")
1676     .Output("max_activations: float")
1677     .Attr("Tinput: quantizedtype")
1678     .Attr("out_type: quantizedtype = DT_QUINT8")
1679     .SetShapeFn([](InferenceContext* c) {
1680       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1681       ShapeHandle unused;
1682       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1683       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1684       c->set_output(1, c->Scalar());
1685       c->set_output(2, c->Scalar());
1686       return OkStatus();
1687     });
1688 
1689 REGISTER_OP("QuantizedReluX")
1690     .Input("features: Tinput")
1691     .Input("max_value: float")
1692     .Input("min_features: float")
1693     .Input("max_features: float")
1694     .Output("activations: out_type")
1695     .Output("min_activations: float")
1696     .Output("max_activations: float")
1697     .Attr("Tinput: quantizedtype")
1698     .Attr("out_type: quantizedtype = DT_QUINT8")
1699     .SetShapeFn([](InferenceContext* c) {
1700       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1701       ShapeHandle unused;
1702       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1703       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1704       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1705       c->set_output(1, c->Scalar());
1706       c->set_output(2, c->Scalar());
1707       return OkStatus();
1708     });
1709 
1710 REGISTER_OP("QuantizedBatchNormWithGlobalNormalization")
1711     .Input("t: Tinput")
1712     .Input("t_min: float")
1713     .Input("t_max: float")
1714     .Input("m: Tinput")
1715     .Input("m_min: float")
1716     .Input("m_max: float")
1717     .Input("v: Tinput")
1718     .Input("v_min: float")
1719     .Input("v_max: float")
1720     .Input("beta: Tinput")
1721     .Input("beta_min: float")
1722     .Input("beta_max: float")
1723     .Input("gamma: Tinput")
1724     .Input("gamma_min: float")
1725     .Input("gamma_max: float")
1726     .Output("result: out_type")
1727     .Output("result_min: float")
1728     .Output("result_max: float")
1729     .Attr("Tinput: quantizedtype")
1730     .Attr("out_type: quantizedtype")
1731     .Attr("variance_epsilon: float")
1732     .Attr("scale_after_normalization: bool")
1733     .SetShapeFn([](InferenceContext* c) {
1734       ShapeHandle input;
1735       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
1736 
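      // m, v, beta and gamma are inputs 3, 6, 9 and 12; each must be a vector
      // whose length matches the depth (last) dimension of t.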
1737       DimensionHandle last_dim = c->Dim(input, 3);
1738       for (int i = 1; i < 5; ++i) {  // covers m, v, beta, gamma
1739         ShapeHandle vec;
1740         TF_RETURN_IF_ERROR(c->WithRank(c->input(i * 3), 1, &vec));
1741         TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
1742       }
1743 
1744       ShapeHandle out;
1745       TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
1746       c->set_output(0, out);
1747       c->set_output(1, c->Scalar());
1748       c->set_output(2, c->Scalar());
1749 
1750       return OkStatus();
1751     });
1752 
1753 #ifdef INTEL_MKL
1754 REGISTER_OP("_MklDepthwiseConv2dNative")
1755     .Input("input: T")
1756     .Input("filter: T")
1757     .Input("mkl_input: uint8")
1758     .Input("mkl_filter: uint8")
1759     .Output("output: T")
1760     .Output("filter_output: T")
1761     .Output("mkl_output: uint8")
1762     .Output("mkl_filter_output: uint8")
1763     .Attr("T: {half, bfloat16, float, double}")
1764     .Attr("strides: list(int)")
1765     .Attr("is_filter_const: bool = false")
1766     .Attr(GetPaddingAttrStringWithExplicit())
1767     .Attr(GetConvnetDataFormatAttrString())
1768     .Attr(GetExplicitPaddingsAttrString())
1769     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1770     .SetShapeFn(shape_inference::DepthwiseConv2DNativeShapeWithExplicitPadding);
1771 
1772 REGISTER_OP("_MklConv2D")
1773     .Input("input: T")
1774     .Input("filter: T")
1775     .Input("mkl_input: uint8")
1776     .Input("mkl_filter: uint8")
1777     .Output("output: T")
1778     .Output("filter_output: T")
1779     .Output("mkl_output: uint8")
1780     .Output("mkl_filter_output: uint8")
1781     .Attr("T: {bfloat16, float}")
1782     .Attr("strides: list(int)")
1783     .Attr("use_cudnn_on_gpu: bool = true")
1784     .Attr("is_filter_const: bool = false")
1785     .Attr(GetPaddingAttrStringWithExplicit())
1786     .Attr(GetConvnetDataFormatAttrString())
1787     .Attr(GetExplicitPaddingsAttrString())
1788     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1789     .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
1790     .Doc(R"doc(
1791 MKL version of Conv2D operator. Uses MKL DNN APIs to perform 2D convolution.
1792 
1793 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1794 expected to invoke these operators.
1795 )doc");
1796 
1797 REGISTER_OP("_MklNativeConv2D")
1798     .Input("input: T")
1799     .Input("filter: T")
1800     .Output("output: T")
1801     .Attr("T: {bfloat16, float}")
1802     .Attr("strides: list(int)")
1803     .Attr("use_cudnn_on_gpu: bool = true")
1804     .Attr("is_filter_const: bool = false")
1805     .Attr(GetPaddingAttrStringWithExplicit())
1806     .Attr(GetExplicitPaddingsAttrString())
1807     .Attr(GetConvnetDataFormatAttrString())
1808     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1809     .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
1810     .Doc(R"doc(
1811 MKL version of Conv2D operator for Eager mode. Uses MKL DNN APIs to perform 2D convolution.
1812 
1813 NOTE Do not invoke this operator directly in Python. Eager Op rewrite is
1814 expected to invoke these operators.
1815 )doc");
1816 
1817 REGISTER_OP("__MklDummyConv2DWithBias")
1818     .Input("input: T")
1819     .Input("filter: T")
1820     .Input("bias: T")
1821     .Output("output: T")
1822     .Attr("T: {bfloat16, float}")
1823     .Attr("strides: list(int)")
1824     .Attr("use_cudnn_on_gpu: bool = true")
1825     .Attr("is_filter_const: bool = false")
1826     .Attr(GetPaddingAttrStringWithExplicit())
1827     .Attr(GetExplicitPaddingsAttrString())
1828     .Attr(GetConvnetDataFormatAttrString())
1829     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1830     .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
1831     .Doc(R"doc(
1832 Dummy node that enables fusing Conv2D and BiasAdd operator for MKL. This node
1833 performs no computation. It is just created as an intermediate output of
1834 merging Conv2D and BiasAdd.
1835 
1836 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1837 expected to invoke these operators.
1838 )doc");
1839 
1840 REGISTER_OP("_MklConv2DWithBias")
1841     .Input("input: T")
1842     .Input("filter: T")
1843     .Input("bias: T")
1844     .Input("mkl_input: uint8")
1845     .Input("mkl_filter: uint8")
1846     .Input("mkl_bias: uint8")
1847     .Output("output: T")
1848     .Output("filter_output: T")
1849     .Output("mkl_output: uint8")
1850     .Output("mkl_filter_output: uint8")
1851     .Attr("T: {bfloat16, float}")
1852     .Attr("strides: list(int)")
1853     .Attr("use_cudnn_on_gpu: bool = true")
1854     .Attr("is_filter_const: bool = false")
1855     .Attr(GetPaddingAttrStringWithExplicit())
1856     .Attr(GetExplicitPaddingsAttrString())
1857     .Attr(GetConvnetDataFormatAttrString())
1858     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1859     .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
1860     .Doc(R"doc(
1861 MKL version of Conv2D and BiasAdd operator. Uses MKL DNN APIs to perform
1862 2D convolution and add Bias to the output of convolution.
1863 
1864 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1865 expected to invoke these operators.
1866 )doc");
1867 
1868 REGISTER_OP("__MklDummyPadWithConv2D")
1869     .Input("input: T")
1870     .Input("filter: T")
1871     .Input("paddings: Tpaddings")
1872     .Output("output: T")
1873     .Attr("T: {bfloat16, float}")
1874     .Attr("strides: list(int)")
1875     .Attr("use_cudnn_on_gpu: bool = true")
1876     .Attr(GetPaddingAttrString())
1877     .Attr(GetConvnetDataFormatAttrString())
1878     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1879     .Attr("Tpaddings: {int32, int64} = DT_INT32")
1880     .SetShapeFn(shape_inference::Conv2DShape)
1881     .Doc(R"doc(
1882 Dummy node that enables fusing Pad and Conv2D operator for MKL. This node
1883 performs no computation. It is just created as an intermediate output of
1884 merging Pad and Conv2D.
1885 
1886 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1887 expected to invoke these operators.
1888 )doc");
1889 
1890 REGISTER_OP("_MklPadWithConv2D")
1891     .Input("input: T")
1892     .Input("filter: T")
1893     .Input("paddings: Tpaddings")
1894     .Input("mkl_input: uint8")
1895     .Input("mkl_filter: uint8")
1896     .Input("mkl_paddings: uint8")
1897     .Output("output: T")
1898     .Output("filter_output: T")
1899     .Output("mkl_output: uint8")
1900     .Output("mkl_filter_output: uint8")
1901     .Attr("T: {bfloat16, float}")
1902     .Attr("strides: list(int)")
1903     .Attr("use_cudnn_on_gpu: bool = true")
1904     .Attr(GetPaddingAttrString())
1905     .Attr(GetConvnetDataFormatAttrString())
1906     .Attr("is_filter_const: bool = false")
1907     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1908     .Attr("Tpaddings: {int32, int64} = DT_INT32")
1909     .SetShapeFn(shape_inference::Conv2DShape)
1910     .Doc(R"doc(
1911 MKL version of the fused Pad and Conv2D operator. Uses MKL DNN APIs to
1912 perform Pad followed by 2D convolution.
1913 
1914 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1915 expected to invoke these operators.
1916 )doc");
1917 
1918 REGISTER_OP("_MklConv2DBackpropFilter")
1919     .Input("input: T")
1920     .Input("filter_sizes: int32")
1921     .Input("out_backprop: T")
1922     .Input("mkl_input: uint8")
1923     .Input("mkl_filter_size: uint8")
1924     .Input("mkl_out_backprop: uint8")
1925     .Output("output: T")
1926     .Output("mkl_output: uint8")
1927     .Attr("T: {bfloat16, float}")
1928     .Attr("strides: list(int)")
1929     .Attr("use_cudnn_on_gpu: bool = true")
1930     .Attr(GetPaddingAttrString())
1931     .Attr(GetConvnetDataFormatAttrString())
1932     .Attr(GetExplicitPaddingsAttrString())
1933     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1934     .SetShapeFn([](InferenceContext* c) {
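      // The filter gradient shape is read from the filter_sizes input
      // (input 1) and constrained to rank 4.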
1935       ShapeHandle s;
1936       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
1937       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1938       c->set_output(0, s);
1939       return Status::OK();
1940     })
1941     .Doc(R"doc(
1942 MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the
1943 gradients of convolution with respect to the filter.
1944 
1945 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1946 expected to invoke these operators.
1947 )doc");
1948 
1949 REGISTER_OP("_MklNativeConv2DBackpropFilter")
1950     .Input("input: T")
1951     .Input("filter_sizes: int32")
1952     .Input("out_backprop: T")
1953     .Output("output: T")
1954     .Attr("T: {bfloat16, float}")
1955     .Attr("strides: list(int)")
1956     .Attr("use_cudnn_on_gpu: bool = true")
1957     .Attr(GetPaddingAttrStringWithExplicit())
1958     .Attr(GetExplicitPaddingsAttrString())
1959     .Attr(GetConvnetDataFormatAttrString())
1960     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1961     .SetShapeFn([](InferenceContext* c) {
1962       ShapeHandle s;
1963       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
1964       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1965       c->set_output(0, s);
1966       return Status::OK();
1967     })
1968     .Doc(R"doc(
1969 MKL version of Conv2DBackpropFilter for Eager mode. Uses MKL DNN APIs
1970 to compute the gradients of convolution with respect to the filter.
1971 
1972 NOTE Do not invoke this operator directly in Python. Eager Op rewrite pass is
1973 expected to invoke these operators.
1974 )doc");
1975 
1976 REGISTER_OP("__MklDummyConv2DBackpropFilterWithBias")
1977     .Input("input: T")
1978     .Input("filter_sizes: int32")
1979     .Input("out_backprop: T")
1980     .Output("output: T")
1981     .Output("bias_grad: T")
1982     .Attr("T: {bfloat16, float}")
1983     .Attr("strides: list(int)")
1984     .Attr("use_cudnn_on_gpu: bool = true")
1985     .Attr(GetPaddingAttrString())
1986     .Attr(GetConvnetDataFormatAttrString())
1987     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1988     .SetShapeFn([](InferenceContext* c) {
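      // bias_grad is a vector whose length is the channel dimension of the
      // input: dimension -3 of the rank-4 input for NCHW, -1 for NHWC.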
1989       ShapeHandle input_shape;
1990       // Fetch the data_format attribute, which may not exist.
1991       string data_format;
1992       Status s = c->GetAttr("data_format", &data_format);
1993 
1994       if (s.ok() && data_format == "NCHW") {
1995         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1996         c->set_output(1, c->Vector(c->Dim(input_shape, -3)));
1997       } else {
1998         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1999         c->set_output(1, c->Vector(c->Dim(input_shape, -1)));
2000       }
2001       ShapeHandle sh;
2002       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh));
2003       TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh));
2004       c->set_output(0, sh);
2005       return Status::OK();
2006     })
2007     .Doc(R"doc(
2008 Dummy node that enables fusing Conv2DBackpropFilter and BiasAddGrad operator
2009 for MKL. This node performs no computation. It is just created as an
2010 intermediate output of merging Conv2DBackpropFilter and BiasAddGrad.
2011 
2012 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2013 expected to invoke these operators.
2014 )doc");
2015 
2016 REGISTER_OP("_MklConv2DBackpropFilterWithBias")
2017     .Input("input: T")
2018     .Input("filter_sizes: int32")
2019     .Input("out_backprop: T")
2020     .Input("mkl_input: uint8")
2021     .Input("mkl_filter_size: uint8")
2022     .Input("mkl_out_backprop: uint8")
2023     .Output("output: T")
2024     .Output("bias_grad: T")
2025     .Output("mkl_output: uint8")
2026     .Output("mkl_bias_grad: uint8")
2027     .Attr("T: {bfloat16, float}")
2028     .Attr("strides: list(int)")
2029     .Attr("use_cudnn_on_gpu: bool = true")
2030     .Attr(GetPaddingAttrString())
2031     .Attr(GetConvnetDataFormatAttrString())
2032     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2033     .SetShapeFn(shape_inference::Conv2DBackpropFilterWithBiasShape)
2034     .Doc(R"doc(
2035 MKL version of Conv2DBackpropFilterWithBias. Uses MKL DNN APIs to compute the
2036 gradients of convolution with respect to the filter and the bias.
2037 
2038 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2039 expected to invoke these operators.
2040 )doc");
2041 
2042 #ifdef INTEL_MKL_ML_ONLY
2043 REGISTER_OP("_MklConv2DWithBiasBackpropBias")
2044     .Input("out_backprop: T")
2045     .Input("mkl_out_backprop: uint8")
2046     .Output("output: T")
2047     .Output("mkl_output: uint8")
2048     .Attr("T: {half, float, double}")
2049     .Attr("strides: list(int)")
2050     .Attr(GetConvnetDataFormatAttrString())
2051     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2052     .Doc(R"doc(
2053 MKL version of Conv2DBackpropBias. Uses MKL DNN APIs to compute the
2054 gradients of convolution with respect to the bias.
2055 
2056 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2057 expected to invoke these operators.
2058 )doc");
2059 #endif
2060 
2061 REGISTER_OP("_MklConv2DBackpropInput")
2062     .Input("input_sizes: int32")
2063     .Input("filter: T")
2064     .Input("out_backprop: T")
2065     .Input("mkl_input_sizes: uint8")
2066     .Input("mkl_filter: uint8")
2067     .Input("mkl_out_backprop: uint8")
2068     .Output("output: T")
2069     .Output("mkl_output: uint8")
2070     .Attr("T: {bfloat16, float}")
2071     .Attr("strides: list(int)")
2072     .Attr("use_cudnn_on_gpu: bool = true")
2073     .Attr(GetPaddingAttrString())
2074     .Attr(GetConvnetDataFormatAttrString())
2075     .Attr(GetExplicitPaddingsAttrString())
2076     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2077     .SetShapeFn([](InferenceContext* c) {
2078       ShapeHandle s;
2079       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
2080       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
2081       c->set_output(0, s);
2082       return Status::OK();
2083     })
2084     .Doc(R"doc(
2085 MKL version of Convolution2D backward input. Uses MKL DNN APIs to compute the
2086 gradients of convolution with respect to the input.
2087 
2088 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2089 expected to invoke these operators.
2090 )doc");
2091 
2092 REGISTER_OP("_MklNativeConv2DBackpropInput")
2093     .Input("input_sizes: int32")
2094     .Input("filter: T")
2095     .Input("out_backprop: T")
2096     .Output("output: T")
2097     .Attr("T: {bfloat16, float}")
2098     .Attr("strides: list(int)")
2099     .Attr("use_cudnn_on_gpu: bool = true")
2100     .Attr(GetPaddingAttrStringWithExplicit())
2101     .Attr(GetExplicitPaddingsAttrString())
2102     .Attr(GetConvnetDataFormatAttrString())
2103     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2104     .SetShapeFn([](InferenceContext* c) {
2105       ShapeHandle s;
2106       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
2107       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
2108       c->set_output(0, s);
2109       return Status::OK();
2110     })
2111     .Doc(R"doc(
2112 MKL version of Convolution2D backward input for Eager mode. Uses MKL DNN APIs
2113 to compute the gradients of convolution with respect to the input.
2114 
2115 NOTE Do not invoke this operator directly in Python. Eager op rewrite is
2116 expected to invoke these operators.
2117 )doc");
2118 
2119 REGISTER_OP("_MklConv3D")
2120     .Input("input: T")
2121     .Input("filter: T")
2122     .Input("mkl_input: uint8")
2123     .Input("mkl_filter: uint8")
2124     .Output("output: T")
2125     .Output("filter_output: T")
2126     .Output("mkl_output: uint8")
2127     .Output("mkl_filter_output: uint8")
2128     .Attr("T: {bfloat16, float}")
2129     .Attr("strides: list(int) >= 5")
2130     .Attr("is_filter_const: bool = false")
2131     .Attr(GetPaddingAttrString())
2132     .Attr(GetConvnet3dDataFormatAttrString())
2133     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
2134     .SetShapeFn(shape_inference::Conv3DShape)
2135     .Doc(R"doc(
2136 MKL version of Conv3D operator. Uses MKL DNN APIs to perform 3D convolution.
2137 
2138 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2139 expected to invoke these operators.
2140 )doc");
2141 
2142 REGISTER_OP("_MklConv3DBackpropInputV2")
2143     .Input("input_sizes: Tshape")
2144     .Input("filter: T")
2145     .Input("out_backprop: T")
2146     .Input("mkl_input_sizes: uint8")
2147     .Input("mkl_filter: uint8")
2148     .Input("mkl_out_backprop: uint8")
2149     .Output("output: T")
2150     .Output("mkl_output: uint8")
2151     .Attr("T: {bfloat16, float}")
2152     .Attr("strides: list(int) >= 5")
2153     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
2154     .Attr("Tshape: {int32, int64} = DT_INT32")
2155     .Attr(GetPaddingAttrString())
2156     .Attr(GetConvnet3dDataFormatAttrString())
2157     .SetShapeFn([](InferenceContext* c) {
2158       ShapeHandle s;
2159       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
2160       TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
2161       c->set_output(0, s);
2162       return Status::OK();
2163     })
2164     .Doc(R"doc(
2165 MKL version of Convolution3D backward input. Uses MKL DNN APIs to compute the
2166 gradients of convolution with respect to the input.
2167 
2168 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2169 expected to invoke these operators.
2170 )doc");
2171 
2172 REGISTER_OP("_MklConv3DBackpropFilterV2")
2173     .Input("input: T")
2174     .Input("filter_sizes: int32")
2175     .Input("out_backprop: T")
2176     .Input("mkl_input: uint8")
2177     .Input("mkl_filter_size: uint8")
2178     .Input("mkl_out_backprop: uint8")
2179     .Output("output: T")
2180     .Output("mkl_output: uint8")
2181     .Attr("T: {bfloat16, float}")
2182     .Attr("strides: list(int)")
2183     .Attr(GetPaddingAttrString())
2184     .Attr(GetConvnet3dDataFormatAttrString())
2185     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
2186     .SetShapeFn([](InferenceContext* c) {
2187       ShapeHandle s;
2188       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
2189       TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
2190       c->set_output(0, s);
2191       return Status::OK();
2192     })
2193     .Doc(R"doc(
2194 MKL version of Conv3DBackpropFilter. Uses MKL DNN APIs to compute the
2195 gradients of convolution with respect to the filter.
2196 
2197 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2198 expected to invoke these operators.
2199 )doc");
2200 
2201 REGISTER_OP("_MklRelu")
2202     .Input("features: T")
2203     .Input("mkl_features: uint8")
2204     .Output("activations: T")
2205     .Output("mkl_activations: uint8")
2206     .Attr("T: {float, bfloat16} = DT_FLOAT")
2207     .SetShapeFn(shape_inference::UnchangedShape)
2208     .Doc(R"doc(
2209 MKL version of Relu operator. Uses MKL DNN APIs to implement Relu operator.
2210 
2211 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2212 expected to invoke these operators.
2213 )doc");
2214 
2215 REGISTER_OP("_MklReluGrad")
2216     .Input("gradients: T")
2217     .Input("features: T")
2218     .Input("mkl_gradients: uint8")
2219     .Input("mkl_features: uint8")
2220     .Output("backprops: T")
2221     .Output("mkl_backprops: uint8")
2222     .Attr("T: {float, bfloat16} = DT_FLOAT")
2223     .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2224     .Doc(R"doc(
2225 MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified
2226 linear gradients for Relu operation.
2227 
2228 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2229 expected to invoke these operators.
2230 )doc");
2231 
2232 REGISTER_OP("_MklRelu6")
2233     .Input("features: T")
2234     .Input("mkl_features: uint8")
2235     .Output("activations: T")
2236     .Output("mkl_activations: uint8")
2237     .Attr("T: {float, bfloat16} = DT_FLOAT")
2238     .SetShapeFn(shape_inference::UnchangedShape)
2239     .Doc(R"doc(
2240 MKL version of Relu6 operator. Uses MKL DNN APIs to implement Relu6 operator.
2241 
2242 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2243 expected to invoke these operators.
2244 )doc");
2245 
2246 REGISTER_OP("_MklRelu6Grad")
2247     .Input("gradients: T")
2248     .Input("features: T")
2249     .Input("mkl_gradients: uint8")
2250     .Input("mkl_features: uint8")
2251     .Output("backprops: T")
2252     .Output("mkl_backprops: uint8")
2253     .Attr("T: {float, bfloat16} = DT_FLOAT")
2254     .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2255     .Doc(R"doc(
2256 MKL version of Relu6Grad operator. Uses MKL DNN APIs to compute rectified
2257 linear gradients for Relu6 operation.
2258 
2259 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2260 expected to invoke these operators.
2261 )doc");
2262 
2263 REGISTER_OP("_MklLeakyRelu")
2264     .Input("features: T")
2265     .Input("mkl_features: uint8")
2266     .Output("activations: T")
2267     .Output("mkl_activations: uint8")
2268     .Attr("T: {float, bfloat16} = DT_FLOAT")
2269     .Attr("alpha: float = 0.2")
2270     .SetShapeFn(shape_inference::UnchangedShape)
2271     .Doc(R"doc(
2272 MKL version of LeakyRelu operator. Uses MKL DNN APIs to implement
2273 LeakyRelu operator.
2274 
2275 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2276 expected to invoke these operators.
2277 )doc");
2278 
2279 REGISTER_OP("_MklLeakyReluGrad")
2280     .Input("gradients: T")
2281     .Input("features: T")
2282     .Input("mkl_gradients: uint8")
2283     .Input("mkl_features: uint8")
2284     .Output("backprops: T")
2285     .Output("mkl_backprops: uint8")
2286     .Attr("T: {float, bfloat16} = DT_FLOAT")
2287     .Attr("alpha: float = 0.2")
2288     .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2289     .Doc(R"doc(
2290 MKL version of LeakyReluGrad operator. Uses MKL DNN APIs to compute rectified
2291 linear gradients for the LeakyRelu operation.
2292 
2293 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2294 expected to invoke these operators.
2295 )doc");
2296 
2297 REGISTER_OP("_MklElu")
2298     .Input("features: T")
2299     .Input("mkl_features: uint8")
2300     .Output("activations: T")
2301     .Output("mkl_activations: uint8")
2302     .Attr("T: {float, bfloat16} = DT_FLOAT")
2303     .SetShapeFn(shape_inference::UnchangedShape)
2304     .Doc(R"doc(
2305 MKL version of Elu operator. Uses MKL DNN APIs to implement Elu operator.
2306 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2307 expected to invoke these operators.
2308 )doc");
2309 
2310 REGISTER_OP("_MklEluGrad")
2311     .Input("gradients: T")
2312     .Input("features: T")
2313     .Input("mkl_gradients: uint8")
2314     .Input("mkl_features: uint8")
2315     .Output("backprops: T")
2316     .Output("mkl_backprops: uint8")
2317     .Attr("T: {float, bfloat16} = DT_FLOAT")
2318     .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2319     .Doc(R"doc(
2320 MKL version of EluGrad operator. Uses MKL DNN APIs to compute Elu
2321 gradients for Elu operation.
2322 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2323 expected to invoke these operators.
2324 )doc");
2325 
2326 REGISTER_OP("_MklSoftmax")
2327     .Input("logits: T")
2328     .Input("mkl_logits: uint8")
2329     .Output("softmax: T")
2330     .Output("mkl_softmax: uint8")
2331     .Attr("T: {bfloat16, half, float, double}")
2332     .SetShapeFn([](InferenceContext* c) {
2333       return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
2334     })
2335     .Doc(R"doc(
2336 MKL version of Softmax operator. Uses MKL DNN APIs to compute softmax
2337 activations.
2338 )doc");
2339 
2340 REGISTER_OP("_MklTanh")
2341     .Input("features: T")
2342     .Input("mkl_features: uint8")
2343     .Output("activations: T")
2344     .Output("mkl_activations: uint8")
2345     .Attr("T: realnumbertype")
2346     .SetShapeFn(shape_inference::UnchangedShape)
2347     .Doc(R"doc(
2348 MKL version of Tanh operator. Uses MKL DNN APIs to implement Tanh operator.
2349 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2350 expected to invoke these operators.
2351 )doc");
2352 
2353 REGISTER_OP("_MklTanhGrad")
2354     .Input("gradients: T")
2355     .Input("features: T")
2356     .Input("mkl_gradients: uint8")
2357     .Input("mkl_features: uint8")
2358     .Output("backprops: T")
2359     .Output("mkl_backprops: uint8")
2360     .Attr("T: realnumbertype")
2361     .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2362     .Doc(R"doc(
2363 MKL version of TanhGrad operator. Uses MKL DNN APIs to compute tanh
2364 gradients for Tanh operation.
2365 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2366 expected to invoke these operators.
2367 )doc");
2368 
2369 REGISTER_OP("_MklMaxPool")
2370     .Attr("T: {float, half, bfloat16} = DT_FLOAT")
2371     .Attr("ksize: list(int) >= 4")
2372     .Attr("strides: list(int) >= 4")
2373     .Attr(GetPaddingAttrString())
2374     .Attr(GetConvnetDataFormatAttrString())
2375     .Attr(GetExplicitPaddingsAttrString())
2376     .Attr("workspace_enabled: bool = false")
2377     .Input("input: T")
2378     .Input("mkl_input: uint8")
2379     .Output("output: T")
2380 #ifdef INTEL_MKL_ML_ONLY
2381     .Output("workspace: T")
2382 #else
2383     .Output("workspace: uint8")
2384 #endif
2385     .Output("mkl_output: uint8")
2386     .Output("mkl_workspace: uint8")
2387     .SetShapeFn(shape_inference::MaxPoolShape)
2388     .Doc(R"doc(
2389 MKL version of MaxPool operator. Uses MKL DNN APIs to perform max pooling
2390 on the input.
2391 
2392 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2393 expected to invoke these operators.
2394 )doc");
2395 
2396 REGISTER_OP("_MklMaxPoolGrad")
2397     .Attr("T: {float, half, bfloat16} = DT_FLOAT")
2398     .Attr("ksize: list(int) >= 4")
2399     .Attr("strides: list(int) >= 4")
2400     .Attr("workspace_enabled: bool = false")
2401     .Attr(GetPaddingAttrString())
2402     .Attr(GetConvnetDataFormatAttrString())
2403     .Attr(GetExplicitPaddingsAttrString())
2404     .Input("orig_input: T")
2405     .Input("orig_output: T")
2406     .Input("grad: T")
2407 #ifdef INTEL_MKL_ML_ONLY
2408     .Input("workspace: T")
2409 #else
2410     .Input("workspace: uint8")
2411 #endif
2412     .Input("mkl_orig_input: uint8")
2413     .Input("mkl_orig_output: uint8")
2414     .Input("mkl_grad: uint8")
2415     .Input("mkl_workspace: uint8")
2416     .Output("output: T")
2417     .Output("mkl_output: uint8")
2418     .SetShapeFn(shape_inference::MaxPoolGradShape)
2419     .Doc(R"doc(
2420 oneDNN version of MaxPoolGrad. Uses oneDNN APIs to compute gradients of
2421 MaxPool operator.
2422 
2423 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2424 expected to invoke these operators.
2425 )doc");
2426 
2427 REGISTER_OP("_MklAvgPool")
2428     .Input("value: T")
2429     .Input("mkl_input: uint8")
2430     .Output("output: T")
2431     .Output("mkl_output: uint8")
2432     .Attr("ksize: list(int) >= 4")
2433     .Attr("strides: list(int) >= 4")
2434     .Attr(GetPaddingAttrString())
2435     .Attr(GetConvnetDataFormatAttrString())
2436     .Attr("T: {float, half, double, bfloat16}")
2437     .SetShapeFn(shape_inference::AvgPoolShape)
2438     .Doc(R"doc(
2439 MKL version of AvgPool operator. Uses MKL DNN APIs to perform average pooling
2440 on the input.
2441 
2442 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2443 expected to invoke these operators.
2444 )doc");
2445 
2446 REGISTER_OP("_MklAvgPoolGrad")
2447     .Input("orig_input_shape: int32")
2448     .Input("grad: T")
2449     .Input("mkl_orig_input: uint8")
2450     .Input("mkl_grad: uint8")
2451     .Output("output: T")
2452     .Output("mkl_output: uint8")
2453     .Attr("ksize: list(int) >= 4")
2454     .Attr("strides: list(int) >= 4")
2455     .Attr(GetPaddingAttrString())
2456     .Attr(GetConvnetDataFormatAttrString())
2457     .Attr("T: {float, half, double, bfloat16}")
2458     .SetShapeFn(shape_inference::AvgPoolGradShape)
2459     .Doc(R"doc(
2460 oneDNN version of AvgPoolGrad operator. Uses oneDNN APIs to compute gradients
2461 of AvgPool function.
2462 
2463 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2464 expected to invoke these operators.
2465 )doc");
2466 
2467 REGISTER_OP("_MklAvgPool3D")
2468     .Input("value: T")
2469     .Input("mkl_input: uint8")
2470     .Output("output: T")
2471     .Output("mkl_output: uint8")
2472     .Attr("ksize: list(int) >= 5")
2473     .Attr("strides: list(int) >= 5")
2474     .Attr(GetPaddingAttrString())
2475     .Attr(GetConvnet3dDataFormatAttrString())
2476     .Attr("T: {float, half, double, bfloat16}")
2477     .SetShapeFn(shape_inference::Pool3DShape)
2478     .Doc(R"doc(
2479 MKL version of AvgPool3D operator. Uses MKL DNN APIs to perform average pooling
2480 on the input.
2481 
2482 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2483 expected to invoke these operators.
2484 )doc");
2485 
2486 REGISTER_OP("_MklAvgPool3DGrad")
2487     .Input("orig_input_shape: int32")
2488     .Input("grad: T")
2489     .Input("mkl_orig_input: uint8")
2490     .Input("mkl_grad: uint8")
2491     .Output("output: T")
2492     .Output("mkl_output: uint8")
2493     .Attr("ksize: list(int) >= 5")
2494     .Attr("strides: list(int) >= 5")
2495     .Attr(GetPaddingAttrString())
2496     .Attr(GetConvnet3dDataFormatAttrString())
2497     .Attr("T: {float, half, double, bfloat16}")
2498     .SetShapeFn(shape_inference::AvgPool3DGradShape)
2499     .Doc(R"doc(
2500 oneDNN version of AvgPool3DGrad operator. Uses oneDNN APIs to compute gradients
2501 of AvgPool function.
2502 
2503 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2504 expected to invoke these operators.
2505 )doc");
2506 
2507 REGISTER_OP("_MklMaxPool3D")
2508     .Input("input: T")
2509     .Input("mkl_input: uint8")
2510     .Output("output: T")
2511     .Output("workspace: uint8")
2512     .Output("mkl_output: uint8")
2513     .Output("mkl_workspace: uint8")
2514     .Attr("ksize: list(int) >= 5")
2515     .Attr("strides: list(int) >= 5")
2516     .Attr(GetPaddingAttrString())
2517     .Attr(GetConvnet3dDataFormatAttrString())
2518     .Attr("T: {half, bfloat16, float}")
2519     .Attr("workspace_enabled: bool = false")
2520     .SetShapeFn(shape_inference::Pool3DShape)
2521     .Doc(R"doc(
2522 MKL version of MaxPool3D operator. Uses MKL DNN APIs to perform max pooling
2523 on the input.
2524 
2525 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2526 expected to invoke these operators.
2527 )doc");
2528 
2529 REGISTER_OP("_MklMaxPool3DGrad")
2530     .Input("orig_input: TInput")
2531     .Input("orig_output: TInput")
2532     .Input("grad: T")
2533     .Input("workspace: uint8")
2534     .Input("mkl_orig_input: uint8")
2535     .Input("mkl_orig_output: uint8")
2536     .Input("mkl_grad: uint8")
2537     .Input("mkl_workspace: uint8")
2538     .Output("output: T")
2539     .Output("mkl_output: uint8")
2540     .Attr("ksize: list(int) >= 5")
2541     .Attr("strides: list(int) >= 5")
2542     .Attr(GetPaddingAttrString())
2543     .Attr(GetConvnet3dDataFormatAttrString())
2544     .Attr("T: {half, bfloat16, float} = DT_FLOAT")
2545     .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
2546     .Attr("workspace_enabled: bool = false")
2547     .SetShapeFn(shape_inference::MaxPool3DGradShape)
2548     .Doc(R"doc(
2549 oneDNN version of MaxPool3DGrad operator. Uses oneDNN APIs to compute gradients
2550 of MaxPool3D function.
2551 
2552 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2553 expected to invoke these operators.
2554 )doc");
2555 
2556 REGISTER_OP("_MklLRN")
2557     .Input("input: T")
2558     .Input("mkl_input: uint8")
2559     .Output("output: T")
2560     .Output("workspace: uint8")
2561     .Output("mkl_output: uint8")
2562     .Output("mkl_workspace: uint8")
2563     .Attr("depth_radius: int = 5")
2564     .Attr("bias: float = 1.0")
2565     .Attr("alpha: float = 1.0")
2566     .Attr("beta: float = 0.5")
2567     .Attr("workspace_enabled: bool = false")
2568     .Attr("T: {float, half} = DT_FLOAT")
2569     .SetShapeFn([](InferenceContext* c) {
2570       return UnchangedShapeWithRank(c, 4);
2571     })
2572     .Doc(R"doc(
2573 MKL version of LRN operator. Uses MKL DNN APIs to perform local response
2574 normalization.
2575 
2576 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2577 expected to invoke these operators.
2578 )doc");
2579 
2580 REGISTER_OP("_MklLRNGrad")
2581     .Input("input_grads: T")
2582     .Input("input_image: T")
2583     .Input("output_image: T")
2584     .Input("workspace: uint8")
2585     .Input("mkl_input_grads: uint8")
2586     .Input("mkl_input_image: uint8")
2587     .Input("mkl_output_image: uint8")
2588     .Input("mkl_workspace: uint8")
2589     .Output("output: T")
2590     .Output("mkl_output: uint8")
2591     .Attr("depth_radius: int = 5")
2592     .Attr("bias: float = 1.0")
2593     .Attr("alpha: float = 1.0")
2594     .Attr("beta: float = 0.5")
2595     .Attr("workspace_enabled: bool = false")
2596     .Attr("T: {float, half} = DT_FLOAT")
2597     .SetShapeFn([](InferenceContext* c) {
2598       ShapeHandle s;
2599       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s));  // input_grads
2600       TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));     // input_image
2601       TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s));     // output_image
2602       c->set_output(0, s);
2603       return Status::OK();
2604     })
2605     .Doc(R"doc(
2606 MKL version of LRNGrad operator. Uses MKL DNN APIs to compute gradient for
2607 local response normalization.
2608 
2609 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2610 expected to invoke these operators.
2611 )doc");
2612 
2613 REGISTER_OP("_MklFusedBatchNorm")
2614     .Input("x: T")
2615     .Input("scale: T")
2616     .Input("offset: T")
2617     .Input("mean: T")
2618     .Input("variance: T")
2619     .Input("mkl_x: uint8")
2620     .Input("mkl_scale: uint8")
2621     .Input("mkl_offset: uint8")
2622     .Input("mkl_mean: uint8")
2623     .Input("mkl_variance: uint8")
2624     .Output("y: T")
2625     .Output("batch_mean: T")
2626     .Output("batch_variance: T")
2627     .Output("reserve_space_1: T")
2628     .Output("reserve_space_2: T")
2629     .Output("mkl_y: uint8")
2630     .Output("mkl_batch_mean: uint8")
2631     .Output("mkl_batch_variance: uint8")
2632     .Output("mkl_reserve_space_1: uint8")
2633     .Output("mkl_reserve_space_2: uint8")
2634     .Attr("T: numbertype")
2635     .Attr("epsilon: float = 0.0001")
2636     .Attr("data_format: string = 'NHWC'")
2637     .Attr("exponential_avg_factor: float = 1.0")
2638     .Attr("is_training: bool = true")
2639     .SetShapeFn(shape_inference::FusedBatchNormShape)
2640     .Doc(R"doc(
2641 oneDNN version of FusedBatchNorm operator. Uses oneDNN APIs to perform fused
2642 batch normalization.
2643 
2644 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2645 expected to invoke these operators.
2646 )doc");
2647 
2648 REGISTER_OP("_MklFusedBatchNormGrad")
2649     .Input("y_backprop: T")
2650     .Input("x: T")
2651     .Input("scale: T")
2652     .Input("reserve_space_1: T")
2653     .Input("reserve_space_2: T")
2654     .Input("mkl_y_backprop: uint8")
2655     .Input("mkl_x: uint8")
2656     .Input("mkl_scale: uint8")
2657     .Input("mkl_reserve_space_1: uint8")
2658     .Input("mkl_reserve_space_2: uint8")
2659     .Output("x_backprop: T")
2660     .Output("scale_backprop: T")
2661     .Output("offset_backprop: T")
2662     .Output("reserve_space_3: T")
2663     .Output("reserve_space_4: T")
2664     .Output("mkl_x_backprop: uint8")
2665     .Output("mkl_scale_backprop: uint8")
2666     .Output("mkl_offset_backprop: uint8")
2667     .Output("mkl_reserve_space_3: uint8")
2668     .Output("mkl_reserve_space_4: uint8")
2669     .Attr("T: numbertype")
2670     .Attr("epsilon: float = 0.0001")
2671     .Attr("data_format: string = 'NHWC'")
2672     .Attr("is_training: bool = true")
2673     .SetShapeFn(shape_inference::FusedBatchNormGradShape)
2674     .Doc(R"doc(
2675 oneDNN version of FusedBatchNormGrad operator. Uses oneDNN APIs to compute
2676 gradients for fused batch normalization.
2677 
2678 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2679 expected to invoke these operators.
2680 )doc");
2681 
2682 REGISTER_OP("_MklFusedBatchNormV2")
2683     .Input("x: T")
2684     .Input("scale: U")
2685     .Input("offset: U")
2686     .Input("mean: U")
2687     .Input("variance: U")
2688     .Input("mkl_x: uint8")
2689     .Input("mkl_scale: uint8")
2690     .Input("mkl_offset: uint8")
2691     .Input("mkl_mean: uint8")
2692     .Input("mkl_variance: uint8")
2693     .Output("y: T")
2694     .Output("batch_mean: U")
2695     .Output("batch_variance: U")
2696     .Output("reserve_space_1: U")
2697     .Output("reserve_space_2: U")
2698     .Output("mkl_y: uint8")
2699     .Output("mkl_batch_mean: uint8")
2700     .Output("mkl_batch_variance: uint8")
2701     .Output("mkl_reserve_space_1: uint8")
2702     .Output("mkl_reserve_space_2: uint8")
2703     .Attr("T: {bfloat16, float}")
2704     .Attr("U: {float}")
2705     .Attr("epsilon: float = 0.0001")
2706     .Attr(GetConvnetDataFormatAttrString())
2707     .Attr("exponential_avg_factor: float = 1.0")
2708     .Attr("is_training: bool = true")
2709     .SetShapeFn(shape_inference::FusedBatchNormShape);
2710 
2711 REGISTER_OP("_MklFusedBatchNormGradV2")
2712     .Input("y_backprop: T")
2713     .Input("x: T")
2714     .Input("scale: float")
2715     .Input("reserve_space_1: U")
2716     .Input("reserve_space_2: U")
2717     .Input("mkl_y_backprop: uint8")
2718     .Input("mkl_x: uint8")
2719     .Input("mkl_scale: uint8")
2720     .Input("mkl_reserve_space_1: uint8")
2721     .Input("mkl_reserve_space_2: uint8")
2722     .Output("x_backprop: T")
2723     .Output("scale_backprop: U")
2724     .Output("offset_backprop: U")
2725     .Output("reserve_space_3: U")
2726     .Output("reserve_space_4: U")
2727     .Output("mkl_x_backprop: uint8")
2728     .Output("mkl_scale_backprop: uint8")
2729     .Output("mkl_offset_backprop: uint8")
2730     .Output("mkl_reserve_space_3: uint8")
2731     .Output("mkl_reserve_space_4: uint8")
2732     .Attr("T: {bfloat16, float}")
2733     .Attr("U: {float}")
2734     .Attr("epsilon: float = 0.0001")
2735     .Attr(GetConvnetDataFormatAttrString())
2736     .Attr("is_training: bool = true")
2737     .SetShapeFn(shape_inference::FusedBatchNormGradShape);
2738 
2739 REGISTER_OP("_MklToTf")
2740     .Input("input: T")
2741     .Input("mkl_input: uint8")
2742     .Output("output: T")
2743     .Attr("T: {half, float, double, bfloat16, qint8, quint8, qint32}")
2744     .Attr(GetConvnetDataFormat2D3DAttrString())
2745     .SetShapeFn(shape_inference::UnknownShape)
2746     .Doc(R"doc(
2747 MKL operator to convert a tensor from MKL layout to TensorFlow layout.
2748 
2749 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2750 expected to invoke these operators.
2751 )doc");
2752 
2753 REGISTER_OP("_MklInputConversion")
2754     .Input("input_0: T")
2755     .Input("input_1: T")
2756     .Input("mkl_input_0: uint8")
2757     .Input("mkl_input_1: uint8")
2758     .Output("output_0: T")
2759     .Output("output_1: T")
2760     .Output("mkl_output_0: uint8")
2761     .Output("mkl_output_1: uint8")
2762     // All datatypes supported by element-wise ops
2763     .Attr(
2764         "T: {half, float, bfloat16, double, uint8, int8, uint16, int16, int32, "
2765         "int64, complex64, complex128}")
2766     .Attr(GetConvnetDataFormat2D3DAttrString())
2767     .SetShapeFn(shape_inference::UnknownShape)
2768     .Doc(R"doc(
2769 MKL operator to process the inputs to an elementwise MKL op. Both inputs
2770 need to be either in TF or in MKL format. This op is added before every
2771 element-wise MKL op.
2772 
2773 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2774 expected to invoke these operators.
2775 )doc");
2776 
2777 #endif  // INTEL_MKL
2778 REGISTER_OP("QuantizedConv2DAndRequantize")
2779     .Input("input: Tinput")
2780     .Input("filter: Tfilter")
2781     .Input("min_input: float")
2782     .Input("max_input: float")
2783     .Input("min_filter: float")
2784     .Input("max_filter: float")
2785     .Input("min_freezed_output: float")
2786     .Input("max_freezed_output: float")
2787     .Output("output: out_type")
2788     .Output("min_output: float")
2789     .Output("max_output: float")
2790     .Attr("Tinput: quantizedtype")
2791     .Attr("Tfilter: quantizedtype")
2792     .Attr("out_type: quantizedtype = DT_QINT8")
2793     .Attr("strides: list(int)")
2794     .Attr(GetPaddingAttrString())
2795     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2796     .Attr("padding_list: list(int) = []")
2797     .SetShapeFn([](InferenceContext* c) {
2798       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2799       ShapeHandle unused;
2800       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2801       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2802       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2803       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2804       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2805       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2806       c->set_output(1, c->Scalar());
2807       c->set_output(2, c->Scalar());
2808       return OkStatus();
2809     });
2810 
2811 // Fusion of Quantized Conv2D and BiasAdd.
2812 REGISTER_OP("QuantizedConv2DWithBias")
2813     .Input("input: Tinput")
2814     .Input("filter: Tfilter")
2815     .Input("bias: float")
2816     .Input("min_input: float")
2817     .Input("max_input: float")
2818     .Input("min_filter: float")
2819     .Input("max_filter: float")
2820     .Output("output: out_type")
2821     .Output("min_output: float")
2822     .Output("max_output: float")
2823     .Attr("Tinput: quantizedtype")
2824     .Attr("Tfilter: quantizedtype")
2825     .Attr("out_type: quantizedtype = DT_QINT32")
2826     .Attr("strides: list(int)")
2827     .Attr(GetPaddingAttrString())
2828     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2829     .Attr("padding_list: list(int) = []")
2830     .SetShapeFn([](InferenceContext* c) {
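      // min_filter/max_filter may be scalars or per-channel vectors;
      // min_output and max_output are given the same (possibly per-channel)
      // shape.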
2831       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2832       ShapeHandle unused, channel;
2833       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2834       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2835       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2836       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2837       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
2838       c->set_output(1, channel);
2839       c->set_output(2, channel);
2840       return OkStatus();
2841     });
2842 
2843 REGISTER_OP("QuantizedConv2DWithBiasAndRequantize")
2844     .Input("input: Tinput")
2845     .Input("filter: Tfilter")
2846     .Input("bias: Tbias")
2847     .Input("min_input: float")
2848     .Input("max_input: float")
2849     .Input("min_filter: float")
2850     .Input("max_filter: float")
2851     .Input("min_freezed_output: float")
2852     .Input("max_freezed_output: float")
2853     .Output("output: out_type")
2854     .Output("min_output: float")
2855     .Output("max_output: float")
2856     .Attr("Tinput: quantizedtype")
2857     .Attr("Tfilter: quantizedtype")
2858     .Attr("Tbias: {float, qint32}")
2859     .Attr("out_type: quantizedtype = DT_QINT8")
2860     .Attr("strides: list(int)")
2861     .Attr(GetPaddingAttrString())
2862     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2863     .Attr("padding_list: list(int) = []")
2864     .SetShapeFn([](InferenceContext* c) {
2865       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2866       ShapeHandle unused, channel;
2867       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2868       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2869       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2870       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2871       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
2872       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2873       TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
2874       c->set_output(1, c->Scalar());
2875       c->set_output(2, c->Scalar());
2876       return OkStatus();
2877     });
2878 
2879 // Fusion of Quantized Conv2D and Relu.
2880 REGISTER_OP("QuantizedConv2DAndRelu")
2881     .Input("input: Tinput")
2882     .Input("filter: Tfilter")
2883     .Input("min_input: float")
2884     .Input("max_input: float")
2885     .Input("min_filter: float")
2886     .Input("max_filter: float")
2887     .Output("output: out_type")
2888     .Output("min_output: float")
2889     .Output("max_output: float")
2890     .Attr("Tinput: quantizedtype")
2891     .Attr("Tfilter: quantizedtype")
2892     .Attr("out_type: quantizedtype = DT_QINT32")
2893     .Attr("strides: list(int)")
2894     .Attr(GetPaddingAttrString())
2895     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2896     .Attr("padding_list: list(int) = []")
2897     .SetShapeFn([](InferenceContext* c) {
2898       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2899       ShapeHandle unused, channel;
2900       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2901       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2902       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel));
2903       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2904       c->set_output(1, channel);
2905       c->set_output(2, channel);
2906       return OkStatus();
2907     });
2908 
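// Fusion of Quantized Conv2D, Relu, and Requantize.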
2909 REGISTER_OP("QuantizedConv2DAndReluAndRequantize")
2910     .Input("input: Tinput")
2911     .Input("filter: Tfilter")
2912     .Input("min_input: float")
2913     .Input("max_input: float")
2914     .Input("min_filter: float")
2915     .Input("max_filter: float")
2916     .Input("min_freezed_output: float")
2917     .Input("max_freezed_output: float")
2918     .Output("output: out_type")
2919     .Output("min_output: float")
2920     .Output("max_output: float")
2921     .Attr("Tinput: quantizedtype")
2922     .Attr("Tfilter: quantizedtype")
2923     .Attr("out_type: quantizedtype = DT_QUINT8")
2924     .Attr("strides: list(int)")
2925     .Attr(GetPaddingAttrString())
2926     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2927     .Attr("padding_list: list(int) = []")
2928     .SetShapeFn([](InferenceContext* c) {
2929       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2930       ShapeHandle unused, channel;
2931       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2932       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2933       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel));
2934       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2935       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2936       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2937       c->set_output(1, c->Scalar());
2938       c->set_output(2, c->Scalar());
2939       return OkStatus();
2940     });
2941 
2942 // Fusion of Quantized Conv2D, BiasAdd, and Relu.
2943 REGISTER_OP("QuantizedConv2DWithBiasAndRelu")
2944     .Input("input: Tinput")
2945     .Input("filter: Tfilter")
2946     .Input("bias: float")
2947     .Input("min_input: float")
2948     .Input("max_input: float")
2949     .Input("min_filter: float")
2950     .Input("max_filter: float")
2951     .Output("output: out_type")
2952     .Output("min_output: float")
2953     .Output("max_output: float")
2954     .Attr("Tinput: quantizedtype")
2955     .Attr("Tfilter: quantizedtype")
2956     .Attr("out_type: quantizedtype = DT_QINT32")
2957     .Attr("strides: list(int)")
2958     .Attr(GetPaddingAttrString())
2959     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2960     .Attr("padding_list: list(int) = []")
2961     .SetShapeFn([](InferenceContext* c) {
2962       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2963       ShapeHandle unused, channel;
2964       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2965       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2966       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2967       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2968       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
2969       c->set_output(1, channel);
2970       c->set_output(2, channel);
2971       return OkStatus();
2972     });
2973 
2974 // Fusion of Quantized Conv2D, BiasAdd, Relu, and Requantize.
2975 REGISTER_OP("QuantizedConv2DWithBiasAndReluAndRequantize")
2976     .Input("input: Tinput")
2977     .Input("filter: Tfilter")
2978     .Input("bias: Tbias")
2979     .Input("min_input: float")
2980     .Input("max_input: float")
2981     .Input("min_filter: float")
2982     .Input("max_filter: float")
2983     .Input("min_freezed_output: float")
2984     .Input("max_freezed_output: float")
2985     .Output("output: out_type")
2986     .Output("min_output: float")
2987     .Output("max_output: float")
2988     .Attr("Tinput: quantizedtype")
2989     .Attr("Tfilter: quantizedtype")
2990     .Attr("Tbias: {float, qint32}")
2991     .Attr("out_type: quantizedtype = DT_QUINT8")
2992     .Attr("strides: list(int)")
2993     .Attr(GetPaddingAttrString())
2994     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2995     .Attr("padding_list: list(int) = []")
2996     .SetShapeFn([](InferenceContext* c) {
2997       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2998       ShapeHandle unused, channel;
2999       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3000       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3001       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3002       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3003       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
3004       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3005       TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3006       c->set_output(1, c->Scalar());
3007       c->set_output(2, c->Scalar());
3008       return OkStatus();
3009     });
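
// The bias + requantize registrations above repeat an identical shape lambda.
// A minimal refactoring sketch (hypothetical, not part of this file) that
// would factor it into a named helper:
//
//   Status QuantizedConv2DWithBiasAndRequantizeShape(InferenceContext* c) {
//     TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
//     ShapeHandle unused, channel;
//     TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));   // bias
//     TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));   // min_input
//     TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));   // max_input
//     TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));  // min_filter
//     TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));  // max_filter
//     TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));   // min_freezed_output
//     TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));   // max_freezed_output
//     c->set_output(1, c->Scalar());
//     c->set_output(2, c->Scalar());
//     return OkStatus();
//   }
//
// Each such registration could then use
// .SetShapeFn(QuantizedConv2DWithBiasAndRequantizeShape) instead of the lambda.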
3010 
3011 // Fusion of Quantized Conv2D, BiasAdd, Sum, and Relu.
3012 REGISTER_OP("QuantizedConv2DWithBiasSumAndRelu")
3013     .Input("input: Tinput")
3014     .Input("filter: Tfilter")
3015     .Input("bias: float")
3016     .Input("min_input: float")
3017     .Input("max_input: float")
3018     .Input("min_filter: float")
3019     .Input("max_filter: float")
3020     .Input("summand: float")
3021     .Output("output: out_type")
3022     .Output("min_output: float")
3023     .Output("max_output: float")
3024     .Attr("Tinput: quantizedtype")
3025     .Attr("Tfilter: quantizedtype")
3026     .Attr("out_type: quantizedtype = DT_QINT32")
3027     .Attr("strides: list(int)")
3028     .Attr(GetPaddingAttrString())
3029     .Attr("dilations: list(int) = [1, 1, 1, 1]")
3030     .Attr("padding_list: list(int) = []")
3031     .SetShapeFn([](InferenceContext* c) {
3032       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
3033       ShapeHandle unused, channel;
3034       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3035       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3036       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3037       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3038       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
3039       c->set_output(1, channel);
3040       c->set_output(2, channel);
3041       return OkStatus();
3042     });
3043 
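// Fusion of Quantized Conv2D, BiasAdd, Sum, Relu, and Requantize.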
3044 REGISTER_OP("QuantizedConv2DWithBiasSumAndReluAndRequantize")
3045     .Input("input: Tinput")
3046     .Input("filter: Tfilter")
3047     .Input("bias: Tbias")
3048     .Input("min_input: float")
3049     .Input("max_input: float")
3050     .Input("min_filter: float")
3051     .Input("max_filter: float")
3052     .Input("min_freezed_output: float")
3053     .Input("max_freezed_output: float")
3054     .Input("summand: Tsummand")
3055     .Input("min_summand: float")
3056     .Input("max_summand: float")
3057     .Output("output: out_type")
3058     .Output("min_output: float")
3059     .Output("max_output: float")
3060     .Attr("Tinput: quantizedtype")
3061     .Attr("Tfilter: quantizedtype")
3062     .Attr("Tbias: {float, qint32}")
3063     .Attr("Tsummand: quantizedtype")
3064     .Attr("out_type: quantizedtype = DT_QUINT8")
3065     .Attr("strides: list(int)")
3066     .Attr(GetPaddingAttrString())
3067     .Attr("dilations: list(int) = [1, 1, 1, 1]")
3068     .Attr("padding_list: list(int) = []")
3069     .SetShapeFn([](InferenceContext* c) {
3070       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
3071       ShapeHandle unused, channel;
3072       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3073       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3074       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3075       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3076       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
3077       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3078       TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3079       c->set_output(1, c->Scalar());
3080       c->set_output(2, c->Scalar());
3081       return OkStatus();
3082     });
3083 
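// Same fusion as above, except that the summand is a signed quantized tensor.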
3084 REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
3085     .Input("input: Tinput")
3086     .Input("filter: Tfilter")
3087     .Input("bias: Tbias")
3088     .Input("min_input: float")
3089     .Input("max_input: float")
3090     .Input("min_filter: float")
3091     .Input("max_filter: float")
3092     .Input("min_freezed_output: float")
3093     .Input("max_freezed_output: float")
3094     .Input("summand: Tsummand")
3095     .Input("min_summand: float")
3096     .Input("max_summand: float")
3097     .Output("output: out_type")
3098     .Output("min_output: float")
3099     .Output("max_output: float")
3100     .Attr("Tinput: quantizedtype")
3101     .Attr("Tfilter: quantizedtype")
3102     .Attr("Tbias: {float, qint32}")
3103     .Attr("Tsummand: quantizedtype")
3104     .Attr("out_type: quantizedtype = DT_QUINT8")
3105     .Attr("strides: list(int)")
3106     .Attr(GetPaddingAttrString())
3107     .Attr("dilations: list(int) = [1, 1, 1, 1]")
3108     .Attr("padding_list: list(int) = []")
3109     .SetShapeFn([](InferenceContext* c) {
3110       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
3111       ShapeHandle unused, channel;
3112       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3113       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3114       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3115       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3116       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
3117       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3118       TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3119       // Since activations are not requantized per channel, `min_output`
3120       // and `max_output` are scalars.
3121       c->set_output(1, c->Scalar());
3122       c->set_output(2, c->Scalar());
3123       return OkStatus();
3124     });
3125 
3126 // Fusion of Quantized MatMul and BiasAdd.
3127 REGISTER_OP("QuantizedMatMulWithBias")
3128     .Input("a: T1")
3129     .Input("b: T2")
3130     .Input("bias: Tbias")
3131     .Input("min_a: float")
3132     .Input("max_a: float")
3133     .Input("min_b: float")
3134     .Input("max_b: float")
3135     .Output("out: Toutput")
3136     .Output("min_out: float")
3137     .Output("max_out: float")
3138     .Attr("T1: quantizedtype")
3139     .Attr("T2: quantizedtype")
3140     .Attr("Tbias: {float, qint32}")
3141     .Attr("Toutput: quantizedtype = DT_QINT32")
3142     .Attr("transpose_a: bool = false")
3143     .Attr("transpose_b: bool = false")
3144     .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
3145     .SetShapeFn([](InferenceContext* c) {
3146       TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3147       ShapeHandle unused;
3148       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3149       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3150       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3151       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3152       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3153       c->set_output(1, c->Scalar());
3154       c->set_output(2, c->Scalar());
3155       return OkStatus();
3156     });
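
// Shape example (illustrative): `a` of shape [batch, k], `b` of shape [k, n]
// (with transpose_a = transpose_b = false), and `bias` of shape [n] infer
// `out` of shape [batch, n]; `min_out` and `max_out` are scalars.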
3157 
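// Fusion of Quantized MatMul, BiasAdd, and Relu.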
3158 REGISTER_OP("QuantizedMatMulWithBiasAndRelu")
3159     .Input("a: T1")
3160     .Input("b: T2")
3161     .Input("bias: float")
3162     .Input("min_a: float")
3163     .Input("max_a: float")
3164     .Input("min_b: float")
3165     .Input("max_b: float")
3166     .Output("out: Toutput")
3167     .Output("min_out: float")
3168     .Output("max_out: float")
3169     .Attr("T1: quantizedtype")
3170     .Attr("T2: quantizedtype")
3171     .Attr("Toutput: quantizedtype = DT_QINT32")
3172     .Attr("transpose_a: bool = false")
3173     .Attr("transpose_b: bool = false")
3174     .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
3175     .SetShapeFn([](InferenceContext* c) {
3176       TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3177       ShapeHandle unused;
3178       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3179       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3180       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3181       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3182       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3183       c->set_output(1, c->Scalar());
3184       c->set_output(2, c->Scalar());
3185       return OkStatus();
3186     });
3187 
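// Fusion of Quantized MatMul, BiasAdd, Relu, and Requantize.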
3188 REGISTER_OP("QuantizedMatMulWithBiasAndReluAndRequantize")
3189     .Input("a: T1")
3190     .Input("b: T2")
3191     .Input("bias: Tbias")
3192     .Input("min_a: float")
3193     .Input("max_a: float")
3194     .Input("min_b: float")
3195     .Input("max_b: float")
3196     .Input("min_freezed_output: float")
3197     .Input("max_freezed_output: float")
3198     .Output("out: Toutput")
3199     .Output("min_out: float")
3200     .Output("max_out: float")
3201     .Attr("T1: quantizedtype")
3202     .Attr("T2: quantizedtype")
3203     .Attr("Tbias: {float, qint32}")
3204     .Attr("Toutput: quantizedtype = DT_QUINT8")
3205     .Attr("transpose_a: bool = false")
3206     .Attr("transpose_b: bool = false")
3207     .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
3208     .SetShapeFn([](InferenceContext* c) {
3209       TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3210       ShapeHandle unused;
3211       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3212       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3213       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3214       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3215       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3216       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3217       TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3218       c->set_output(1, c->Scalar());
3219       c->set_output(2, c->Scalar());
3220       return OkStatus();
3221     });
3222 
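// Fusion of Quantized MatMul, BiasAdd, and Dequantize. The result is
// dequantized to float, so there are no min/max outputs.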
3223 REGISTER_OP("QuantizedMatMulWithBiasAndDequantize")
3224     .Input("a: T1")
3225     .Input("b: T2")
3226     .Input("bias: Tbias")
3227     .Input("min_a: float")
3228     .Input("max_a: float")
3229     .Input("min_b: float")
3230     .Input("max_b: float")
3231     .Input("min_freezed_output: float")
3232     .Input("max_freezed_output: float")
3233     .Output("out: Toutput")
3234     .Attr("T1: quantizedtype")
3235     .Attr("T2: quantizedtype")
3236     .Attr("Tbias: {float, qint32}")
3237     .Attr("Toutput: {float}")
3238     .Attr("transpose_a: bool = false")
3239     .Attr("transpose_b: bool = false")
3240     .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
3241     .SetShapeFn([](InferenceContext* c) {
3242       TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3243       ShapeHandle unused;
3244       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3245       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3246       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3247       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3248       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3249       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3250       TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3251 
3252       return OkStatus();
3253     });
3254 
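// Fusion of Quantized MatMul, BiasAdd, and Requantize.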
3255 REGISTER_OP("QuantizedMatMulWithBiasAndRequantize")
3256     .Input("a: T1")
3257     .Input("b: T2")
3258     .Input("bias: Tbias")
3259     .Input("min_a: float")
3260     .Input("max_a: float")
3261     .Input("min_b: float")
3262     .Input("max_b: float")
3263     .Input("min_freezed_output: float")
3264     .Input("max_freezed_output: float")
3265     .Output("out: Toutput")
3266     .Output("min_out: float")
3267     .Output("max_out: float")
3268     .Attr("T1: quantizedtype")
3269     .Attr("T2: quantizedtype")
3270     .Attr("Tbias: {float, qint32}")
3271     .Attr("Toutput: quantizedtype = DT_QUINT8")
3272     .Attr("transpose_a: bool = false")
3273     .Attr("transpose_b: bool = false")
3274     .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
3275     .SetShapeFn([](InferenceContext* c) {
3276       TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3277       ShapeHandle unused;
3278       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3279       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3280       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3281       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3282       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3283       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3284       TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3285       c->set_output(1, c->Scalar());
3286       c->set_output(2, c->Scalar());
3287       return OkStatus();
3288     });
3289 
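// Quantized Conv2D with per-channel quantization of the filter:
// `min_filter`/`max_filter` may be rank-1 tensors with one entry per output
// channel.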
3290 REGISTER_OP("QuantizedConv2DPerChannel")
3291     .Input("input: Tinput")
3292     .Input("filter: Tfilter")
3293     .Input("min_input: float")
3294     .Input("max_input: float")
3295     .Input("min_filter: float")
3296     .Input("max_filter: float")
3297     .Output("output: out_type")
3298     .Output("min_output: float")
3299     .Output("max_output: float")
3300     .Attr("Tinput: quantizedtype")
3301     .Attr("Tfilter: quantizedtype")
3302     .Attr("out_type: quantizedtype = DT_QINT32")
3303     .Attr("strides: list(int)")
3304     .Attr(GetPaddingAttrString())
3305     .Attr("dilations: list(int) = [1, 1, 1, 1]")
3306     .SetShapeFn([](InferenceContext* c) {
3307       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
3308       ShapeHandle unused, channel;
3309       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
3310       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3311       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel));
3312       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3313       c->set_output(1, channel);
3314       c->set_output(2, channel);
3315       return OkStatus();
3316     });
3317 
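// Quantized depthwise Conv2D and its fused variants (bias, bias + relu,
// bias + relu + requantize) follow; all of them use
// DepthwiseConv2DNativeShape for shape inference.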
3318 REGISTER_OP("QuantizedDepthwiseConv2D")
3319     .Input("input: Tinput")
3320     .Input("filter: Tfilter")
3321     .Input("min_input: float")
3322     .Input("max_input: float")
3323     .Input("min_filter: float")
3324     .Input("max_filter: float")
3325     .Output("output: out_type")
3326     .Output("min_output: float")
3327     .Output("max_output: float")
3328     .Attr("Tinput: quantizedtype")
3329     .Attr("Tfilter: quantizedtype")
3330     .Attr("out_type: quantizedtype = DT_QINT32")
3331     .Attr("strides: list(int)")
3332     .Attr(GetPaddingAttrString())
3333     .Attr("dilations: list(int) = [1, 1, 1, 1]")
3334     .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
3335 
3336 REGISTER_OP("QuantizedDepthwiseConv2DWithBias")
3337     .Input("input: Tinput")
3338     .Input("filter: Tfilter")
3339     .Input("bias: float")
3340     .Input("min_input: float")
3341     .Input("max_input: float")
3342     .Input("min_filter: float")
3343     .Input("max_filter: float")
3344     .Output("output: out_type")
3345     .Output("min_output: float")
3346     .Output("max_output: float")
3347     .Attr("Tinput: quantizedtype")
3348     .Attr("Tfilter: quantizedtype")
3349     .Attr("out_type: quantizedtype = DT_QINT32")
3350     .Attr("strides: list(int)")
3351     .Attr(GetPaddingAttrString())
3352     .Attr("dilations: list(int) = [1, 1, 1, 1]")
3353     .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
3354 
3355 REGISTER_OP("QuantizedDepthwiseConv2DWithBiasAndRelu")
3356     .Input("input: Tinput")
3357     .Input("filter: Tfilter")
3358     .Input("bias: float")
3359     .Input("min_input: float")
3360     .Input("max_input: float")
3361     .Input("min_filter: float")
3362     .Input("max_filter: float")
3363     .Output("output: out_type")
3364     .Output("min_output: float")
3365     .Output("max_output: float")
3366     .Attr("Tinput: quantizedtype")
3367     .Attr("Tfilter: quantizedtype")
3368     .Attr("out_type: quantizedtype = DT_QINT32")
3369     .Attr("strides: list(int)")
3370     .Attr(GetPaddingAttrString())
3371     .Attr("dilations: list(int) = [1, 1, 1, 1]")
3372     .Attr("padding_list: list(int) = []")
3373     .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
3374 
3375 REGISTER_OP("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")
3376     .Input("input: Tinput")
3377     .Input("filter: Tfilter")
3378     .Input("bias: Tbias")
3379     .Input("min_input: float")
3380     .Input("max_input: float")
3381     .Input("min_filter: float")
3382     .Input("max_filter: float")
3383     .Input("min_freezed_output: float")
3384     .Input("max_freezed_output: float")
3385     .Output("output: out_type")
3386     .Output("min_output: float")
3387     .Output("max_output: float")
3388     .Attr("Tinput: quantizedtype")
3389     .Attr("Tfilter: quantizedtype")
3390     .Attr("Tbias: {float, qint32}")
3391     .Attr("out_type: quantizedtype = DT_QUINT8")
3392     .Attr("strides: list(int)")
3393     .Attr(GetPaddingAttrString())
3394     .Attr("dilations: list(int) = [1, 1, 1, 1]")
3395     .Attr("padding_list: list(int) = []")
3396     .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
3397 
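// Isotonic regression: both the solution and the int32 segment ids have the
// same shape as the input.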
3398 REGISTER_OP("IsotonicRegression")
3399     .Input("input: T")
3400     .Output("output: output_dtype")
3401     .Output("segments: int32")
3402     .Attr("T: realnumbertype")
3403     .Attr("output_dtype: {half, bfloat16, float, double} = DT_FLOAT")
3404     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* context) {
3405       context->set_output(0, context->input(0));
3406       context->set_output(1, context->input(0));
3407       return OkStatus();
3408     });
3409 
3410 }  // namespace tensorflow
3411