1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <algorithm>
#include <cmath>
#include <limits>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
29
30 namespace tensorflow {
31
32 using shape_inference::DimensionHandle;
33 using shape_inference::InferenceContext;
34 using shape_inference::ShapeHandle;
35
36 namespace {
37
FractionalPoolShapeFn(InferenceContext * c)38 Status FractionalPoolShapeFn(InferenceContext* c) {
39 ShapeHandle input;
40 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
41
42 std::vector<float> pooling_ratio;
43 TF_RETURN_IF_ERROR(c->GetAttr("pooling_ratio", &pooling_ratio));
44 if (pooling_ratio.size() != 4) {
45 return errors::InvalidArgument(
46 "pooling_ratio field must specify 4 dimensions");
47 }
48 std::vector<DimensionHandle> output_dims;
49 for (int i = 0; i < 4; ++i) {
50 DimensionHandle d = c->Dim(input, i);
51 if (c->ValueKnown(d)) {
52 // This must match the same logic in the kernel function in
53 // core/kernels/fractional_max_pool_op.cc.
54 auto val =
55 static_cast<int64_t>(std::floor(c->Value(d) / pooling_ratio[i]));
56 if (val < 0) {
57 return errors::InvalidArgument("Size computed for dim ", i,
58 " is negative: ", val);
59 }
60 output_dims.push_back(c->MakeDim(val));
61 } else {
62 output_dims.push_back(c->UnknownDim());
63 }
64 }
65
66 c->set_output(0, c->MakeShape(output_dims));
67 c->set_output(1, c->Vector(output_dims[1]));
68 c->set_output(2, c->Vector(output_dims[2]));
69 return OkStatus();
70 }
71
72 } // namespace
73
74 // --------------------------------------------------------------------------
75
// AvgPool: average pooling of `value`. It is configured by the ksize and
// strides attrs (at least 4 entries each), the padding attr, and the 2-D
// convnet data_format attr. The output shape is computed by
// shape_inference::AvgPoolShape.
REGISTER_OP("AvgPool")
    .Input("value: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::AvgPoolShape);

// AvgPoolGrad: gradient op paired with AvgPool.
// `orig_input_shape` is an int32 tensor; presumably it holds the shape of
// AvgPool's forward input (confirm against AvgPoolGradShape). The output
// shape is computed by shape_inference::AvgPoolGradShape.
REGISTER_OP("AvgPoolGrad")
    .Input("orig_input_shape: int32")
    .Input("grad: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::AvgPoolGradShape);
96
97 // --------------------------------------------------------------------------
98
99 REGISTER_OP("BatchNormWithGlobalNormalization")
100 .Input("t: T")
101 .Input("m: T")
102 .Input("v: T")
103 .Input("beta: T")
104 .Input("gamma: T")
105 .Output("result: T")
106 .Attr("T: numbertype")
107 .Attr("variance_epsilon: float")
108 .Attr("scale_after_normalization: bool")
109 .Deprecated(9, "Use tf.nn.batch_normalization()")
__anon84cbdd650202(InferenceContext* c) 110 .SetShapeFn([](InferenceContext* c) {
111 ShapeHandle input;
112 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
113
114 DimensionHandle last_dim = c->Dim(input, 3);
115 for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma
116 ShapeHandle vec;
117 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
118 TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
119 }
120
121 ShapeHandle out;
122 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
123 c->set_output(0, out);
124 return OkStatus();
125 });
126
127 REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
128 .Input("t: T")
129 .Input("m: T")
130 .Input("v: T")
131 .Input("gamma: T")
132 .Input("backprop: T")
133 .Output("dx: T")
134 .Output("dm: T")
135 .Output("dv: T")
136 .Output("db: T")
137 .Output("dg: T")
138 .Attr("T: numbertype")
139 .Attr("variance_epsilon: float")
140 .Attr("scale_after_normalization: bool")
141 .Deprecated(9, "Use tf.nn.batch_normalization()")
__anon84cbdd650302(InferenceContext* c) 142 .SetShapeFn([](InferenceContext* c) {
143 ShapeHandle input;
144 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
145 TF_RETURN_IF_ERROR(
146 c->Merge(input, c->input(4), &input)); // with backprop
147
148 DimensionHandle last_dim = c->Dim(input, 3);
149 for (int i = 1; i < 4; ++i) { // covers m, v, gamma
150 ShapeHandle vec;
151 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
152 TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
153 }
154
155 ShapeHandle dx;
156 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &dx));
157 c->set_output(0, dx);
158
159 ShapeHandle vector_shape = c->Vector(last_dim);
160 c->set_output(1, vector_shape);
161 c->set_output(2, vector_shape);
162 c->set_output(3, vector_shape);
163 c->set_output(4, vector_shape);
164 return OkStatus();
165 });
166
167 // --------------------------------------------------------------------------
168
// FusedBatchNorm: fused batch normalization. Every input and output uses the
// single type T, restricted to float. The shape function is
// shape_inference::FusedBatchNormShape.
REGISTER_OP("FusedBatchNorm")
    .Input("x: T")
    .Input("scale: T")
    .Input("offset: T")
    .Input("mean: T")
    .Input("variance: T")
    .Output("y: T")
    .Output("batch_mean: T")
    .Output("batch_variance: T")
    .Output("reserve_space_1: T")
    .Output("reserve_space_2: T")
    .Attr("T: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr("exponential_avg_factor: float = 1.0")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormShape);

// FusedBatchNormV2: same interface as FusedBatchNorm, but with two types.
// `x` and `y` use T (half, bfloat16 or float). The scale/offset/mean/variance
// inputs and all statistics outputs use a separate type U (float).
REGISTER_OP("FusedBatchNormV2")
    .Input("x: T")
    .Input("scale: U")
    .Input("offset: U")
    .Input("mean: U")
    .Input("variance: U")
    .Output("y: T")
    .Output("batch_mean: U")
    .Output("batch_variance: U")
    .Output("reserve_space_1: U")
    .Output("reserve_space_2: U")
    .Attr("T: {half, bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr("exponential_avg_factor: float = 1.0")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormShape);

// FusedBatchNormV3: extends V2 in three ways:
//   * adds a sixth output, `reserve_space_3`;
//   * allows U to be bfloat16;
//   * uses the 2-D/3-D data format attr.
// The shape function is shape_inference::FusedBatchNormV3Shape.
REGISTER_OP("FusedBatchNormV3")
    .Input("x: T")
    .Input("scale: U")
    .Input("offset: U")
    .Input("mean: U")
    .Input("variance: U")
    .Output("y: T")
    .Output("batch_mean: U")
    .Output("batch_variance: U")
    .Output("reserve_space_1: U")
    .Output("reserve_space_2: U")
    .Output("reserve_space_3: U")
    .Attr("T: {half, bfloat16, float}")
    .Attr("U: {bfloat16, float}")
    .Attr("epsilon: float = 0.0001")
    .Attr("exponential_avg_factor: float = 1.0")
    .Attr(GetConvnetDataFormat2D3DAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormV3Shape);

// _FusedBatchNormEx: internal fusion target (see the Doc below). Its
// interface differs from FusedBatchNormV3 as follows:
//   * it takes `num_side_inputs` extra `side_input` tensors of type T;
//   * it adds an `activation_mode` attr;
//   * it uses the 2-D-only data format attr.
// The shape function is shape_inference::FusedBatchNormExShape.
REGISTER_OP("_FusedBatchNormEx")
    .Input("x: T")
    .Input("scale: U")
    .Input("offset: U")
    .Input("mean: U")
    .Input("variance: U")
    .Input("side_input: num_side_inputs * T")
    .Output("y: T")
    .Output("batch_mean: U")
    .Output("batch_variance: U")
    .Output("reserve_space_1: U")
    .Output("reserve_space_2: U")
    .Output("reserve_space_3: U")
    .Attr("T: {half, float, bfloat16}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr("exponential_avg_factor: float = 1.0")
    .Attr("num_side_inputs: int >= 0 = 0")
    .Attr("activation_mode: string = \"Identity\"")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormExShape)
    .Doc(R"doc(
Internal FusedBatchNorm operation: reserved for internal use.

Do not invoke this operator directly in Python. A fusion optimization is
expected to create these operators.
)doc");
254
// FusedBatchNormGrad: gradient op for FusedBatchNorm. All tensors use T
// (float). All four Grad variants below use
// shape_inference::FusedBatchNormGradShape, except the Ex variant, which
// uses FusedBatchNormGradExShape.
REGISTER_OP("FusedBatchNormGrad")
    .Input("y_backprop: T")
    .Input("x: T")
    .Input("scale: T")
    .Input("reserve_space_1: T")
    .Input("reserve_space_2: T")
    .Output("x_backprop: T")
    .Output("scale_backprop: T")
    .Output("offset_backprop: T")
    .Output("reserve_space_3: T")
    .Output("reserve_space_4: T")
    .Attr("T: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormGradShape);

// FusedBatchNormGradV2: mixed-precision gradient.
//   * `y_backprop`, `x` and `x_backprop` use T.
//   * The reserve spaces and the scale/offset gradients use U.
//   * `scale` is declared as plain float rather than U.
// NOTE(review): U is currently restricted to float, so the two coincide.
REGISTER_OP("FusedBatchNormGradV2")
    .Input("y_backprop: T")
    .Input("x: T")
    .Input("scale: float")
    .Input("reserve_space_1: U")
    .Input("reserve_space_2: U")
    .Output("x_backprop: T")
    .Output("scale_backprop: U")
    .Output("offset_backprop: U")
    .Output("reserve_space_3: U")
    .Output("reserve_space_4: U")
    .Attr("T: {half, bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormGradShape);

// FusedBatchNormGradV3: extends V2 in two ways:
//   * adds a third reserve-space input (`reserve_space_3`);
//   * uses the 2-D/3-D data format attr.
REGISTER_OP("FusedBatchNormGradV3")
    .Input("y_backprop: T")
    .Input("x: T")
    .Input("scale: float")
    .Input("reserve_space_1: U")
    .Input("reserve_space_2: U")
    .Input("reserve_space_3: U")
    .Output("x_backprop: T")
    .Output("scale_backprop: U")
    .Output("offset_backprop: U")
    .Output("reserve_space_4: U")
    .Output("reserve_space_5: U")
    .Attr("T: {half, bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormat2D3DAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormGradShape);

// _FusedBatchNormGradEx: internal fusion target (see the Doc below).
//   * It also takes the forward `offset` and output `y` as inputs.
//   * It emits `num_side_inputs` side-input gradients of type T.
//   * Unlike the public V2/V3 ops, T excludes bfloat16 here.
REGISTER_OP("_FusedBatchNormGradEx")
    .Input("y_backprop: T")
    .Input("x: T")
    .Input("scale: float")
    .Input("reserve_space_1: U")
    .Input("reserve_space_2: U")
    .Input("reserve_space_3: U")
    .Input("offset: float")
    .Input("y: T")
    .Output("x_backprop: T")
    .Output("scale_backprop: U")
    .Output("offset_backprop: U")
    .Output("reserve_space_4: U")
    .Output("reserve_space_5: U")
    .Output("side_input_backprop: num_side_inputs * T")
    .Attr("T: {half, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr("num_side_inputs: int >= 0 = 0")
    .Attr("activation_mode: string = \"Identity\"")
    .Attr(GetConvnetDataFormat2D3DAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormGradExShape)
    .Doc(R"doc(
Internal FusedBatchNormGrad operation: reserved for internal use.

Do not invoke this operator directly in Python. A fusion optimization is
expected to create these operators.
)doc");
338 // --------------------------------------------------------------------------
339
// BiasAdd: adds `bias` to `value`, honoring the data_format attr. The shape
// function is shape_inference::BiasAddShape.
REGISTER_OP("BiasAdd")
    .Attr("T: numbertype")
    .Input("value: T")
    .Input("bias: T")
    .Attr(GetConvnetDataFormatAttrString())
    .Output("output: T")
    .SetShapeFn(shape_inference::BiasAddShape);
// --------------------------------------------------------------------------

// BiasAddGrad: gradient of BiasAdd with respect to `bias`, computed from
// `out_backprop`. The shape function is shape_inference::BiasAddGradShape.
REGISTER_OP("BiasAddGrad")
    .Attr("T: numbertype")
    .Input("out_backprop: T")
    .Attr(GetConvnetDataFormatAttrString())
    .Output("output: T")
    .SetShapeFn(shape_inference::BiasAddGradShape);
// --------------------------------------------------------------------------

// BiasAddV1: older BiasAdd variant without a data_format attr. It shares
// BiasAdd's shape function.
REGISTER_OP("BiasAddV1")
    .Attr("T: numbertype")
    .Input("value: T")
    .Input("bias: T")
    .Output("output: T")
    .SetShapeFn(shape_inference::BiasAddShape);
363 // --------------------------------------------------------------------------
364
// Conv2D: 2-D convolution of `input` with `filter`. It supports SAME/VALID
// and EXPLICIT padding (via `explicit_paddings`) plus dilations. The output
// shape is computed by shape_inference::Conv2DShapeWithExplicitPadding.
REGISTER_OP("Conv2D")
    .Input("input: T")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double, int32}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding);

// Conv2DBackpropInput: gradient of Conv2D with respect to its input.
// `input_sizes` is an int32 tensor. The output shape is computed by
// shape_inference::Conv2DBackpropInputShape.
REGISTER_OP("Conv2DBackpropInput")
    .Input("input_sizes: int32")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double, int32}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv2DBackpropInputShape);
391
// TODO(jeff): Instead of 'use_cudnn_on_gpu', maybe we should have a
// more general string attribute ('kernel_impl'?) that can be used to
// select among several possible implementations.
395 REGISTER_OP("Conv2DBackpropFilter")
396 .Input("input: T")
397 .Input("filter_sizes: int32")
398 .Input("out_backprop: T")
399 .Output("output: T")
400 .Attr("T: {half, bfloat16, float, double}")
401 .Attr("strides: list(int)")
402 .Attr("use_cudnn_on_gpu: bool = true")
403 .Attr(GetPaddingAttrStringWithExplicit())
404 .Attr(GetExplicitPaddingsAttrString())
405 .Attr(GetConvnetDataFormatAttrString())
406 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon84cbdd650402(InferenceContext* c) 407 .SetShapeFn([](InferenceContext* c) {
408 ShapeHandle s;
409 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
410 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
411 c->set_output(0, s);
412 return OkStatus();
413 });
414
// _FusedConv2D: internal Conv2D fused with the ops listed in `fused_ops`.
// `args` carries `num_args` extra inputs for those ops. The output shape is
// the plain Conv2D output shape (Conv2DShapeWithExplicitPadding).
// NOTE(review): the Doc text below lists only Elu/Relu/Relu6 as activations,
// yet a `leakyrelu_alpha` attr is registered. Confirm whether the Doc is out
// of date.
REGISTER_OP("_FusedConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
    .Doc(R"doc(
Performs a convolution followed by a specified series of operations.

The inputs to the convolution are `input` and `filter`. The series of operations
that follows is specified by the `fused_ops` attribute, which is a list of TF op
names specified as strings (e.g. "Relu"). They are performed in order, where the
(first) input to each op is the output of the preceding op. The first input and
the output of each fused_op must be of type T.

Currently supported fused_op combinations are: [X] and [X,A], where X is one of
{"BiasAdd","FusedBatchNorm"} and A is one of {"Elu","Relu","Relu6"}.

* The first input to op X is the Conv2D result, and the additional input(s) to X
are specified by `args`.
* If there is an op A specified, the output of op X is the input to op A, and op
A produces the _FusedConv2D output. Otherwise, op X produces the _FusedConv2D
output.

*NOTE*: Do not invoke this operator directly in Python. Grappler is expected to
create these operators.
)doc");
456
457 namespace {
458
CommonFusedConvCalculations(InferenceContext * c,bool has_resize)459 Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) {
460 ShapeHandle input;
461 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
462
463 ShapeHandle resized = input;
464 int paddings_index = 1;
465 int filter_index = 2;
466 if (has_resize) {
467 paddings_index = 2;
468 filter_index = 3;
469
470 ShapeHandle unused_size;
471 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused_size));
472
473 const Tensor* size = c->input_tensor(1);
474 DimensionHandle new_height = c->UnknownDim();
475 DimensionHandle new_width = c->UnknownDim();
476 if (size != nullptr) {
477 new_height = c->MakeDim(size->flat<int32>()(0));
478 new_width = c->MakeDim(size->flat<int32>()(1));
479 }
480 TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 1, new_height, &resized));
481 TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 2, new_width, &resized));
482 }
483
484 ShapeHandle paddings;
485 TF_RETURN_IF_ERROR(c->WithRank(c->input(paddings_index), 2, &paddings));
486 TF_RETURN_IF_ERROR(
487 c->WithRank(resized, c->Value(c->Dim(paddings, 0)), &resized));
488 TF_RETURN_IF_ERROR(
489 c->Merge(paddings, c->Matrix(c->Rank(resized), 2), &paddings));
490
491 const Tensor* paddings_t = c->input_tensor(paddings_index);
492 ShapeHandle padded;
493 if (paddings_t != nullptr) {
494 std::vector<DimensionHandle> output_dims;
495 for (int i = 0; i < 4; ++i) {
496 DimensionHandle dim = c->Dim(resized, i);
497 int64_t p0 = static_cast<int64_t>(paddings_t->matrix<int32>()(i, 0));
498 int64_t p1 = static_cast<int64_t>(paddings_t->matrix<int32>()(i, 1));
499 if (p0 < 0 || p1 < 0) {
500 return errors::InvalidArgument("Paddings must be non-negative");
501 }
502
503 TF_RETURN_IF_ERROR(c->Add(dim, p0 + p1, &dim));
504 output_dims.push_back(dim);
505 }
506 padded = c->MakeShape(output_dims);
507 } else {
508 padded = c->UnknownShapeOfRank(4);
509 }
510
511 // Work out the convolution's effect with 'padded' as the input.
512 ShapeHandle filter;
513 TF_RETURN_IF_ERROR(c->WithRank(c->input(filter_index), 4, &filter));
514 std::vector<int32> strides;
515 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
516 if (strides.size() != 4) {
517 return errors::InvalidArgument(
518 "Operation requires the stride attribute to contain 4 values, but ",
519 "got: ", strides.size());
520 }
521
522 int32_t stride_rows = strides[1];
523 int32_t stride_cols = strides[2];
524
525 DimensionHandle batch_size_dim = c->Dim(padded, 0);
526 DimensionHandle in_rows_dim = c->Dim(padded, 1);
527 DimensionHandle in_cols_dim = c->Dim(padded, 2);
528 DimensionHandle filter_rows_dim = c->Dim(filter, 0);
529 DimensionHandle filter_cols_dim = c->Dim(filter, 1);
530 DimensionHandle output_depth_dim = c->Dim(filter, 3);
531
532 DimensionHandle unused;
533 TF_RETURN_IF_ERROR(c->Merge(c->Dim(padded, 3), c->Dim(filter, 2), &unused));
534
535 Padding padding;
536 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
537
538 DimensionHandle output_rows, output_cols;
539 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
540 c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
541 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
542 c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
543
544 ShapeHandle output_shape = c->MakeShape(
545 {batch_size_dim, output_rows, output_cols, output_depth_dim});
546 c->set_output(0, output_shape);
547 return OkStatus();
548 }
549
550 } // namespace
551
// DataFormatDimMap: element-wise op mapping `x` (int32/int64) from
// `src_format` to `dst_format`. The output has the same shape as the input.
REGISTER_OP("DataFormatDimMap")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {int32, int64} = DT_INT32")
    .Attr("src_format: string = 'NHWC'")
    .Attr("dst_format: string = 'NCHW'")
    .SetShapeFn(shape_inference::UnchangedShape);

// DataFormatVecPermute: permutes `x` (int32/int64) from `src_format` to
// `dst_format` ordering. The output has the same shape as the input.
REGISTER_OP("DataFormatVecPermute")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {int32, int64} = DT_INT32")
    .Attr("src_format: string = 'NHWC'")
    .Attr("dst_format: string = 'NCHW'")
    .SetShapeFn(shape_inference::UnchangedShape);
567
// FusedResizeAndPadConv2D: resize (to the int32 `size`), then mirror-pad (by
// `paddings`), then Conv2D, as one op. Shape inference is delegated to
// CommonFusedConvCalculations with the resize step enabled.
REGISTER_OP("FusedResizeAndPadConv2D")
    .Input("input: T")
    .Input("size: int32")
    .Input("paddings: int32")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr("resize_align_corners: bool = false")
    .Attr(GetMirrorPadModeAttrString())
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .SetShapeFn([](InferenceContext* c) {
      return CommonFusedConvCalculations(c, true /* has_resize */);
    });

// FusedPadConv2D: mirror-pad (by `paddings`) followed by Conv2D, with no
// resize step. Shape inference uses CommonFusedConvCalculations without
// resize.
REGISTER_OP("FusedPadConv2D")
    .Input("input: T")
    .Input("paddings: int32")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr(GetMirrorPadModeAttrString())
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .SetShapeFn([](InferenceContext* c) {
      return CommonFusedConvCalculations(c, false /* has_resize */);
    });
595
596 // --------------------------------------------------------------------------
597
// DepthwiseConv2dNative: depthwise 2-D convolution. It supports explicit
// padding; the output shape is computed by
// shape_inference::DepthwiseConv2DNativeShapeWithExplicitPadding.
REGISTER_OP("DepthwiseConv2dNative")
    .Input("input: T")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShapeWithExplicitPadding);
609
610 REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
611 .Input("input_sizes: int32")
612 .Input("filter: T")
613 .Input("out_backprop: T")
614 .Output("output: T")
615 .Attr("T: {half, bfloat16, float, double}")
616 .Attr("strides: list(int)")
617 .Attr(GetPaddingAttrStringWithExplicit())
618 .Attr(GetExplicitPaddingsAttrString())
619 .Attr(GetConvnetDataFormatAttrString())
620 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon84cbdd650802(InferenceContext* c) 621 .SetShapeFn([](InferenceContext* c) {
622 ShapeHandle s;
623 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
624 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
625 c->set_output(0, s);
626 return OkStatus();
627 });
628
629 REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
630 .Input("input: T")
631 .Input("filter_sizes: int32")
632 .Input("out_backprop: T")
633 .Output("output: T")
634 .Attr("T: {half, bfloat16, float, double}")
635 .Attr("strides: list(int)")
636 .Attr(GetPaddingAttrStringWithExplicit())
637 .Attr(GetExplicitPaddingsAttrString())
638 .Attr(GetConvnetDataFormatAttrString())
639 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon84cbdd650902(InferenceContext* c) 640 .SetShapeFn([](InferenceContext* c) {
641 ShapeHandle s;
642 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
643 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
644 c->set_output(0, s);
645 return OkStatus();
646 });
647
// _FusedDepthwiseConv2dNative: internal depthwise convolution fused with the
// ops listed in `fused_ops`, which take their extra inputs from `args`. Only
// non-explicit padding is accepted. The output shape is that of
// DepthwiseConv2DNativeShape.
REGISTER_OP("_FusedDepthwiseConv2dNative")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
666
667 // --------------------------------------------------------------------------
668
// Conv3D: 3-D convolution. The strides and dilations attrs have (at least) 5
// entries. The output shape is computed by shape_inference::Conv3DShape.
REGISTER_OP("Conv3D")
    .Input("input: T")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv3DShape);
679
680 REGISTER_OP("Conv3DBackpropInput")
681 .Input("input: T")
682 .Input("filter: T")
683 .Input("out_backprop: T")
684 .Output("output: T")
685 .Attr("T: {half, float, double}")
686 .Attr("strides: list(int) >= 5")
687 .Attr(GetPaddingAttrString())
688 .Deprecated(10, "Use Conv3DBackpropInputV2")
689 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
__anon84cbdd650a02(InferenceContext* c) 690 .SetShapeFn([](InferenceContext* c) {
691 return UnchangedShapeWithRank(c, 5);
692 });
693
694 REGISTER_OP("Conv3DBackpropFilter")
695 .Input("input: T")
696 .Input("filter: T")
697 .Input("out_backprop: T")
698 .Output("output: T")
699 .Attr("T: {half, float, double}")
700 .Attr("strides: list(int) >= 5")
701 .Attr(GetPaddingAttrString())
702 .Deprecated(10, "Use Conv3DBackpropFilterV2")
703 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
__anon84cbdd650b02(InferenceContext* c) 704 .SetShapeFn([](InferenceContext* c) {
705 ShapeHandle out;
706 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &out));
707 c->set_output(0, out);
708 return OkStatus();
709 });
710
711 REGISTER_OP("Conv3DBackpropInputV2")
712 .Input("input_sizes: Tshape")
713 .Input("filter: T")
714 .Input("out_backprop: T")
715 .Output("output: T")
716 .Attr("T: {half, bfloat16, float, double}")
717 .Attr("strides: list(int) >= 5")
718 .Attr(GetPaddingAttrString())
719 .Attr(GetConvnet3dDataFormatAttrString())
720 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
721 .Attr("Tshape: {int32, int64} = DT_INT32")
__anon84cbdd650c02(InferenceContext* c) 722 .SetShapeFn([](InferenceContext* c) {
723 ShapeHandle s;
724 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
725 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
726 c->set_output(0, s);
727 return OkStatus();
728 });
729
730 REGISTER_OP("Conv3DBackpropFilterV2")
731 .Input("input: T")
732 .Input("filter_sizes: int32")
733 .Input("out_backprop: T")
734 .Output("output: T")
735 .Attr("T: {half, bfloat16, float, double}")
736 .Attr("strides: list(int) >= 5")
737 .Attr(GetPaddingAttrString())
738 .Attr(GetConvnet3dDataFormatAttrString())
739 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
__anon84cbdd650d02(InferenceContext* c) 740 .SetShapeFn([](InferenceContext* c) {
741 ShapeHandle s;
742 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
743 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
744 c->set_output(0, s);
745 return OkStatus();
746 });
747
748 // --------------------------------------------------------------------------
749
// AvgPool3D: 3-D average pooling. The ksize and strides attrs have at least 5
// entries. The output shape is computed by shape_inference::Pool3DShape.
REGISTER_OP("AvgPool3D")
    .Input("input: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::Pool3DShape);

// AvgPool3DGrad: gradient op paired with AvgPool3D. `orig_input_shape` is an
// int32 tensor. The output shape is computed by
// shape_inference::AvgPool3DGradShape.
REGISTER_OP("AvgPool3DGrad")
    .Input("orig_input_shape: int32")
    .Input("grad: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::AvgPool3DGradShape);
770
771 // --------------------------------------------------------------------------
772
// MaxPool3D: 3-D max pooling. It shares Pool3DShape with AvgPool3D.
REGISTER_OP("MaxPool3D")
    .Input("input: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {half, bfloat16, float}")
    .SetShapeFn(shape_inference::Pool3DShape);

// MaxPool3DGrad: gradient op paired with MaxPool3D.
//   * The forward tensors (`orig_input`, `orig_output`) use type TInput.
//   * `grad` and the output use type T.
// The shape function is shape_inference::MaxPool3DGradShape.
REGISTER_OP("MaxPool3DGrad")
    .Input("orig_input: TInput")
    .Input("orig_output: TInput")
    .Input("grad: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {half, bfloat16, float} = DT_FLOAT")
    .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
    .SetShapeFn(shape_inference::MaxPool3DGradShape);
795
796 REGISTER_OP("MaxPool3DGradGrad")
797 .Input("orig_input: T")
798 .Input("orig_output: T")
799 .Input("grad: T")
800 .Output("output: T")
801 .Attr("ksize: list(int) >= 5 ")
802 .Attr("strides: list(int) >= 5")
803 .Attr(GetPaddingAttrString())
804 .Attr(GetConvnet3dDataFormatAttrString())
805 .Attr("T: realnumbertype")
__anon84cbdd650e02(InferenceContext* c) 806 .SetShapeFn([](InferenceContext* c) {
807 TF_RETURN_IF_ERROR(shape_inference::Pool3DShape(c));
808 ShapeHandle unused;
809 // Validate 'orig_input' is the same shape as 'grad'
810 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
811 // Validate 'orig_output' is same shape as 'output'
812 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
813 return OkStatus();
814 });
815
816 // --------------------------------------------------------------------------
817
// L2Loss: reduces `t` to a single value, so the output shape is a scalar
// (ScalarShape).
REGISTER_OP("L2Loss")
    .Input("t: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::ScalarShape);
823
824 // --------------------------------------------------------------------------
825
826 REGISTER_OP("LRN")
827 .Input("input: T")
828 .Output("output: T")
829 .Attr("depth_radius: int = 5")
830 .Attr("bias: float = 1.0")
831 .Attr("alpha: float = 1.0")
832 .Attr("beta: float = 0.5")
833 .Attr("T: {half, bfloat16, float} = DT_FLOAT")
__anon84cbdd650f02(InferenceContext* c) 834 .SetShapeFn([](InferenceContext* c) {
835 return UnchangedShapeWithRank(c, 4);
836 });
837
838 REGISTER_OP("LRNGrad")
839 .Input("input_grads: T")
840 .Input("input_image: T")
841 .Input("output_image: T")
842 .Output("output: T")
843 .Attr("depth_radius: int = 5")
844 .Attr("bias: float = 1.0")
845 .Attr("alpha: float = 1.0")
846 .Attr("beta: float = 0.5")
847 .Attr("T: {half, bfloat16, float} = DT_FLOAT")
__anon84cbdd651002(InferenceContext* c) 848 .SetShapeFn([](InferenceContext* c) {
849 ShapeHandle s;
850 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads
851 TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image
852 TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image
853 c->set_output(0, s);
854 return OkStatus();
855 });
856
857 // --------------------------------------------------------------------------
858
859 REGISTER_OP("MaxPool")
860 .Attr(
861 "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
862 "uint16, qint8} = DT_FLOAT")
863 .Attr("ksize: list(int) >= 4")
864 .Attr("strides: list(int) >= 4")
865 .Attr(GetPaddingAttrStringWithExplicit())
866 .Attr(GetExplicitPaddingsAttrString())
867 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
868 .Input("input: T")
869 .Output("output: T")
870 .SetShapeFn(shape_inference::MaxPoolShapeWithExplicitPadding);
871
872 REGISTER_OP("MaxPoolV2")
873 .Attr(
874 "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
875 "uint16, qint8} = DT_FLOAT")
876 .Attr(GetPaddingAttrString())
877 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
878 .Input("input: T")
879 .Input("ksize: int32")
880 .Input("strides: int32")
881 .Output("output: T")
__anon84cbdd651102(InferenceContext* c) 882 .SetShapeFn([](InferenceContext* c) {
883 TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 3));
884 return OkStatus();
885 });
886
// MaxPoolGrad: gradient of MaxPool. Takes the forward op's input and output
// plus the incoming gradient; the output shape is computed by
// MaxPoolGradShape.
REGISTER_OP("MaxPoolGrad")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Input("orig_input: T")
    .Input("orig_output: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: realnumbertype = DT_FLOAT")
    .SetShapeFn(shape_inference::MaxPoolGradShape);
899
// MaxPoolGradV2: MaxPoolGrad with `ksize` and `strides` supplied as int32
// input tensors (inputs 3 and 4) rather than attributes. It shares
// MaxPoolGradShape with MaxPoolGrad.
REGISTER_OP("MaxPoolGradV2")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Input("orig_input: T")
    .Input("orig_output: T")
    .Input("grad: T")
    .Input("ksize: int32")
    .Input("strides: int32")
    .Output("output: T")
    .Attr("T: realnumbertype = DT_FLOAT")
    .SetShapeFn(shape_inference::MaxPoolGradShape);
911
// TODO(b/150813181): Implement explicit padding.
// MaxPoolGradGrad: second-order gradient of MaxPool. The output has the
// pooled shape, i.e. what MaxPoolShape computes from `orig_input`.
REGISTER_OP("MaxPoolGradGrad")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Input("orig_input: T")
    .Input("orig_output: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      // Sets output(0) from input(0) using ksize/strides/padding attrs.
      TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
      ShapeHandle unused;
      // Validate 'orig_input' is the same shape as 'grad'
      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
      // Validate 'orig_output' is same shape as 'output'
      TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
      return OkStatus();
    });
932
// MaxPoolGradGradV2: MaxPoolGradGrad with `ksize` and `strides` supplied as
// int32 input tensors (inputs 3 and 4) rather than attributes.
REGISTER_OP("MaxPoolGradGradV2")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Input("orig_input: T")
    .Input("orig_output: T")
    .Input("grad: T")
    .Input("ksize: int32")
    .Input("strides: int32")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      // Pooled output shape; 5 presumably is the op's input count (ksize and
      // strides are the last two inputs) -- confirm against MaxPoolV2Shape.
      TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 5));
      ShapeHandle unused;
      // Validate 'orig_input' is the same shape as 'grad'
      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
      // Validate 'orig_output' is same shape as 'output'
      TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
      return OkStatus();
    });
952
// MaxPoolWithArgmax: max pooling that also returns the position of each
// selected maximum. `argmax` (int32 or int64) has exactly the pooled shape of
// `output`.
REGISTER_OP("MaxPoolWithArgmax")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr("Targmax: {int32, int64} = DT_INT64")
    .Attr(GetPaddingAttrString())
    .Attr("include_batch_in_index: bool = false")
    .Input("input: T")
    .Output("output: T")
    .Output("argmax: Targmax")
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
      // argmax mirrors the pooled output shape.
      c->set_output(1, c->output(0));
      return OkStatus();
    });
968
// MaxPoolGradWithArgmax: gradient of MaxPoolWithArgmax. The output has the
// shape of `input`, which must be rank 4. `grad` and `argmax` shapes are not
// checked here.
REGISTER_OP("MaxPoolGradWithArgmax")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr("include_batch_in_index: bool = false")
    .Attr("Targmax: {int32, int64}")
    .Input("input: T")
    .Input("grad: T")
    .Input("argmax: Targmax")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      return UnchangedShapeWithRank(c, 4);
    });
983
// MaxPoolGradGradWithArgmax: second-order gradient for MaxPoolWithArgmax.
// The output has the pooled shape that MaxPoolShape computes from `input`.
REGISTER_OP("MaxPoolGradGradWithArgmax")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr("include_batch_in_index: bool = false")
    .Attr("Targmax: {int32, int64}")
    .Input("input: T")
    .Input("grad: T")
    .Input("argmax: Targmax")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
      ShapeHandle unused;
      // Validate 'input' is the same shape as 'grad'
      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &unused));
      // Validate 'argmax' is same shape as 'output'
      TF_RETURN_IF_ERROR(c->Merge(c->input(2), c->output(0), &unused));
      return OkStatus();
    });
1004
1005 // --------------------------------------------------------------------------
1006
1007 REGISTER_OP("Dilation2D")
1008 .Input("input: T")
1009 .Input("filter: T")
1010 .Output("output: T")
1011 .Attr("T: realnumbertype")
1012 .Attr("strides: list(int) >= 4")
1013 .Attr("rates: list(int) >= 4")
1014 .Attr(GetPaddingAttrString())
__anon84cbdd651702(InferenceContext* c) 1015 .SetShapeFn([](InferenceContext* c) {
1016 ShapeHandle input_shape;
1017 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1018 ShapeHandle filter_shape;
1019 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &filter_shape));
1020
1021 std::vector<int32> strides;
1022 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1023 if (strides.size() != 4) {
1024 return errors::InvalidArgument(
1025 "Dilation2D requires the stride attribute to contain 4 values, but "
1026 "got: ",
1027 strides.size());
1028 }
1029
1030 std::vector<int32> rates;
1031 TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
1032 if (rates.size() != 4) {
1033 return errors::InvalidArgument(
1034 "Dilation2D requires the rates attribute to contain 4 values, but "
1035 "got: ",
1036 rates.size());
1037 }
1038
1039 int32_t stride_rows = strides[1];
1040 int32_t stride_cols = strides[2];
1041
1042 int32_t rate_rows = rates[1];
1043 int32_t rate_cols = rates[2];
1044
1045 DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
1046 DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
1047 DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
1048 DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
1049 DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
1050 DimensionHandle output_depth_dim = c->Dim(filter_shape, 2);
1051
1052 if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim) ||
1053 !c->ValueKnown(filter_rows_dim) || !c->ValueKnown(filter_cols_dim)) {
1054 ShapeHandle output_shape =
1055 c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
1056 InferenceContext::kUnknownDim, output_depth_dim});
1057 c->set_output(0, output_shape);
1058 return OkStatus();
1059 }
1060 DimensionHandle unused;
1061 TF_RETURN_IF_ERROR(
1062 c->Merge(c->Dim(input_shape, 3), output_depth_dim, &unused));
1063
1064 auto in_rows = c->Value(in_rows_dim);
1065 auto in_cols = c->Value(in_cols_dim);
1066 auto filter_rows = c->Value(filter_rows_dim);
1067 auto filter_cols = c->Value(filter_cols_dim);
1068 auto filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_rows - 1);
1069 auto filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_cols - 1);
1070
1071 Padding padding;
1072 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1073
1074 int64_t output_rows, output_cols;
1075 int64_t padding_before, padding_after;
1076 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
1077 in_rows, filter_rows_eff, stride_rows, padding, &output_rows,
1078 &padding_before, &padding_after));
1079 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
1080 in_cols, filter_cols_eff, stride_cols, padding, &output_cols,
1081 &padding_before, &padding_after));
1082
1083 ShapeHandle output_shape = c->MakeShape(
1084 {batch_size_dim, output_rows, output_cols, output_depth_dim});
1085 c->set_output(0, output_shape);
1086 return OkStatus();
1087 });
1088
// Dilation2DBackpropInput: gradient of Dilation2D with respect to `input`.
// `in_backprop` has the shape of input 0 (UnchangedShape).
REGISTER_OP("Dilation2DBackpropInput")
    .Input("input: T")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("in_backprop: T")
    .Attr("T: realnumbertype")
    .Attr("strides: list(int) >= 4")
    .Attr("rates: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .SetShapeFn(shape_inference::UnchangedShape);
1099
// Dilation2DBackpropFilter: gradient of Dilation2D with respect to `filter`.
// `filter_backprop` simply takes the shape of `filter` (input 1); the other
// inputs are not shape-checked here.
REGISTER_OP("Dilation2DBackpropFilter")
    .Input("input: T")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("filter_backprop: T")
    .Attr("T: realnumbertype")
    .Attr("strides: list(int) >= 4")
    .Attr("rates: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->input(1));
      return OkStatus();
    });
1113
1114 // --------------------------------------------------------------------------
1115
// Relu: elementwise activation; `activations` has the shape of `features`.
// T additionally admits qint8.
REGISTER_OP("Relu")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {realnumbertype, qint8}")
    .SetShapeFn(shape_inference::UnchangedShape);

// ReluGrad: backprop for Relu. The output shape is produced by
// MergeBothInputsShapeFn from `gradients` and `features`.
REGISTER_OP("ReluGrad")
    .Input("gradients: T")
    .Input("features: T")
    .Output("backprops: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1128
// Relu6: elementwise activation; `activations` has the shape of `features`.
REGISTER_OP("Relu6")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::UnchangedShape);

// Relu6Grad: backprop for Relu6. The output shape is produced by
// MergeBothInputsShapeFn from `gradients` and `features`.
REGISTER_OP("Relu6Grad")
    .Input("gradients: T")
    .Input("features: T")
    .Output("backprops: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1141
// LeakyRelu: elementwise activation parameterized by the float `alpha`
// attribute (default 0.2); output shape equals input shape.
REGISTER_OP("LeakyRelu")
    .Input("features: T")
    .Output("activations: T")
    .Attr("alpha: float = 0.2")
    .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

// LeakyReluGrad: backprop for LeakyRelu, taking the same `alpha` attribute.
// The output shape is produced by MergeBothInputsShapeFn.
REGISTER_OP("LeakyReluGrad")
    .Input("gradients: T")
    .Input("features: T")
    .Output("backprops: T")
    .Attr("alpha: float = 0.2")
    .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1156
// Elu: elementwise activation; output shape equals input shape.
REGISTER_OP("Elu")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// EluGrad: backprop for Elu. Note that its second input is the forward
// `outputs`, not the `features`. The output shape comes from
// MergeBothInputsShapeFn.
REGISTER_OP("EluGrad")
    .Input("gradients: T")
    .Input("outputs: T")
    .Output("backprops: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1169
// Selu: elementwise activation; output shape equals input shape.
REGISTER_OP("Selu")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// SeluGrad: backprop for Selu. Like EluGrad, its second input is the forward
// `outputs`. The output shape comes from MergeBothInputsShapeFn.
REGISTER_OP("SeluGrad")
    .Input("gradients: T")
    .Input("outputs: T")
    .Output("backprops: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1182
// Softplus: elementwise activation; output shape equals input shape.
REGISTER_OP("Softplus")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// SoftplusGrad: backprop for Softplus. The output shape comes from
// MergeBothInputsShapeFn over `gradients` and `features`.
REGISTER_OP("SoftplusGrad")
    .Input("gradients: T")
    .Input("features: T")
    .Output("backprops: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1195
// Softsign: elementwise activation; output shape equals input shape.
REGISTER_OP("Softsign")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// SoftsignGrad: backprop for Softsign. The output shape comes from
// MergeBothInputsShapeFn over `gradients` and `features`.
REGISTER_OP("SoftsignGrad")
    .Input("gradients: T")
    .Input("features: T")
    .Output("backprops: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1208
1209 // --------------------------------------------------------------------------
1210
// Softmax: `logits` must have rank >= 1; `softmax` has the same shape.
REGISTER_OP("Softmax")
    .Input("logits: T")
    .Output("softmax: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
    });
1218
1219 // --------------------------------------------------------------------------
1220
// LogSoftmax: `logits` must have rank >= 1; `logsoftmax` has the same shape.
REGISTER_OP("LogSoftmax")
    .Input("logits: T")
    .Output("logsoftmax: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
    });
1228
1229 // --------------------------------------------------------------------------
1230
1231 REGISTER_OP("SoftmaxCrossEntropyWithLogits")
1232 .Input("features: T")
1233 .Input("labels: T")
1234 .Output("loss: T")
1235 .Output("backprop: T")
1236 .Attr("T: {half, bfloat16, float, double}")
__anon84cbdd651b02(InferenceContext* c) 1237 .SetShapeFn([](InferenceContext* c) {
1238 ShapeHandle input;
1239 if (c->WithRank(c->input(0), 2, &input) == OkStatus() &&
1240 c->Merge(input, c->input(1), &input) == OkStatus()) {
1241 DimensionHandle batch_size = c->Dim(input, 0);
1242 c->set_output(0, c->Vector(batch_size));
1243 c->set_output(1, input);
1244 return OkStatus();
1245 }
1246 TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFn(c, 1));
1247
1248 if (!c->RankKnown(c->output(1))) {
1249 return errors::InvalidArgument(
1250 "Shape must be broadcasted with rank 2, but is rank is unknown.");
1251 }
1252
1253 if (c->Rank(c->output(1)) != 2) {
1254 return errors::InvalidArgument(
1255 "Shape must be broadcasted with rank 2, but is rank ",
1256 c->Rank(c->output(1)));
1257 }
1258 DimensionHandle batch_size = c->Dim(c->output(1), 0);
1259 c->set_output(0, c->Vector(batch_size));
1260 return OkStatus();
1261 });
1262
1263 REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits")
1264 .Input("features: T")
1265 .Input("labels: Tlabels")
1266 .Output("loss: T")
1267 .Output("backprop: T")
1268 .Attr("T: {half, bfloat16, float, double}")
1269 .Attr("Tlabels: {int32, int64} = DT_INT64")
__anon84cbdd651c02(InferenceContext* c) 1270 .SetShapeFn([](InferenceContext* c) {
1271 ShapeHandle features;
1272 ShapeHandle labels;
1273 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &features));
1274 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &labels));
1275
1276 DimensionHandle batch_size;
1277 TF_RETURN_IF_ERROR(
1278 c->Merge(c->Dim(features, 0), c->Dim(labels, 0), &batch_size));
1279 TF_RETURN_IF_ERROR(c->ReplaceDim(features, 0, batch_size, &features));
1280
1281 c->set_output(0, c->Vector(batch_size));
1282 c->set_output(1, features);
1283 return OkStatus();
1284 });
1285
1286 // --------------------------------------------------------------------------
1287
1288 REGISTER_OP("InTopK")
1289 .Input("predictions: float")
1290 .Input("targets: T")
1291 .Output("precision: bool")
1292 .Attr("k: int")
1293 .Attr("T: {int32, int64} = DT_INT32")
__anon84cbdd651d02(InferenceContext* c) 1294 .SetShapeFn([](InferenceContext* c) {
1295 ShapeHandle predictions;
1296 ShapeHandle targets;
1297 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
1298 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
1299 DimensionHandle batch_size;
1300 TF_RETURN_IF_ERROR(
1301 c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
1302 c->set_output(0, c->Vector(batch_size));
1303 return OkStatus();
1304 });
1305
1306 // This is the same as `InTopK`, but takes `k` as in input rather than an attr.
1307 REGISTER_OP("InTopKV2")
1308 .Input("predictions: float")
1309 .Input("targets: T")
1310 .Input("k: T")
1311 .Output("precision: bool")
1312 .Attr("T: {int32, int64} = DT_INT32")
__anon84cbdd651e02(InferenceContext* c) 1313 .SetShapeFn([](InferenceContext* c) {
1314 ShapeHandle predictions;
1315 ShapeHandle targets;
1316 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
1317 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
1318 DimensionHandle batch_size;
1319 TF_RETURN_IF_ERROR(
1320 c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
1321 c->set_output(0, c->Vector(batch_size));
1322 return OkStatus();
1323 });
1324
1325 namespace {
1326
TopKShapeFn(InferenceContext * c)1327 Status TopKShapeFn(InferenceContext* c) {
1328 ShapeHandle input;
1329 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
1330
1331 // Get the k value, either from input tensor or attribute.
1332 DimensionHandle k_dim;
1333 if (c->num_inputs() >= 2) {
1334 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &k_dim));
1335 } else {
1336 int32_t k;
1337 TF_RETURN_IF_ERROR(c->GetAttr("k", &k));
1338 if (k < 0) {
1339 return errors::InvalidArgument("Need k >= 0, got ", k);
1340 }
1341 k_dim = c->MakeDim(k);
1342 }
1343
1344 DimensionHandle last_dim = c->Dim(input, -1);
1345 if (c->ValueKnown(last_dim) && c->ValueKnown(k_dim) &&
1346 c->Value(last_dim) < c->Value(k_dim)) {
1347 return errors::InvalidArgument(
1348 "input must have last dimension >= k = ", c->Value(k_dim), " but is ",
1349 c->Value(last_dim));
1350 }
1351
1352 // Replace last_dim with k_dim.
1353 ShapeHandle s;
1354 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
1355 TF_RETURN_IF_ERROR(c->Concatenate(s, c->Vector(k_dim), &s));
1356 c->set_output(0, s);
1357 c->set_output(1, s);
1358 return OkStatus();
1359 }
1360
1361 // Utility functions for ApproxTopKShape.
1362 // It is not easy to link xla/client/lib into the tensorflow core lib, so we
1363 // have to replicate the logic.
1364 // LINT.IfChange
log2_floor(uint64_t value)1365 inline uint32_t log2_floor(uint64_t value) {
1366 return value == 0 ? 0 : Log2Floor(value);
1367 }
1368
log2_ceil(uint64_t value)1369 inline uint32_t log2_ceil(uint64_t value) {
1370 return value == 0 ? 0 : Log2Ceiling(value);
1371 }
1372
ApproxTopKShape(shape_inference::InferenceContext * c)1373 Status ApproxTopKShape(shape_inference::InferenceContext* c) {
1374 int64_t k;
1375 int64_t reduction_dimension;
1376 float recall_target;
1377 int64_t reduction_input_size_override;
1378 bool aggregate_to_topk;
1379 TF_RETURN_IF_ERROR(c->GetAttr("k", &k));
1380 TF_RETURN_IF_ERROR(c->GetAttr("reduction_dimension", &reduction_dimension));
1381 TF_RETURN_IF_ERROR(c->GetAttr("recall_target", &recall_target));
1382 TF_RETURN_IF_ERROR(c->GetAttr("reduction_input_size_override",
1383 &reduction_input_size_override));
1384 TF_RETURN_IF_ERROR(c->GetAttr("aggregate_to_topk", &aggregate_to_topk));
1385 ShapeHandle input_shape;
1386 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
1387 if (reduction_dimension < 0) {
1388 // Reverse index
1389 reduction_dimension += c->Rank(input_shape);
1390 }
1391 int64_t reduction_dim_value =
1392 c->Value(c->Dim(input_shape, reduction_dimension));
1393
1394 if (reduction_dim_value < k) {
1395 return errors::InvalidArgument("input must have last dimension >= k = ", k,
1396 " but was ", reduction_dim_value);
1397 }
1398
1399 int64_t output_dim_value = [&] {
1400 if (aggregate_to_topk) {
1401 return k;
1402 }
1403 int64_t tpu_tiling = c->Rank(input_shape) == 1 ? 1024 : 128;
1404 if (reduction_dim_value <= tpu_tiling || recall_target == 1.0) {
1405 return reduction_dim_value;
1406 }
1407 if (k == 1) {
1408 return tpu_tiling;
1409 }
1410 uint64_t logical_input_size = reduction_input_size_override >= 0
1411 ? reduction_input_size_override
1412 : reduction_dim_value;
1413 uint64_t m = std::min<uint64_t>(
1414 std::max<uint64_t>(
1415 static_cast<uint64_t>((1.0 - k) /
1416 std::log(static_cast<double>(recall_target))),
1417 tpu_tiling),
1418 reduction_dim_value);
1419 uint32_t log2_reduction = log2_floor(logical_input_size / m);
1420 if (log2_reduction == 0) {
1421 return reduction_dim_value;
1422 }
1423 log2_reduction = std::min<uint32_t>(
1424 log2_reduction, log2_ceil(reduction_dim_value / tpu_tiling));
1425 return tensorflow::MathUtil::CeilOfRatio<int64_t>(
1426 tensorflow::MathUtil::CeilOfRatio<int64_t>(reduction_dim_value,
1427 tpu_tiling),
1428 (1 << log2_reduction)) *
1429 tpu_tiling;
1430 }();
1431
1432 auto output_dim = c->MakeDim(output_dim_value);
1433
1434 ShapeHandle output_shape;
1435 TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, reduction_dimension, output_dim,
1436 &output_shape));
1437 c->set_output(0, output_shape);
1438 c->set_output(1, output_shape);
1439 return OkStatus();
1440 }
1441 // LINT.ThenChange(//tensorflow/compiler/xla/client/lib/approx_topk_shape.cc)
1442
1443 } // namespace
1444
// TopK: top `k` entries along the last dimension, with `k` as a non-negative
// attribute. Deprecated at GraphDef version 7 in favor of TopKV2. The shape
// function (TopKShapeFn) is shared with TopKV2.
REGISTER_OP("TopK")
    .Input("input: T")
    .Output("values: T")
    .Output("indices: int32")
    .Attr("k: int >= 0")
    .Attr("sorted: bool = true")
    .Attr("T: realnumbertype")
    .Deprecated(7, "Use TopKV2 instead")
    .SetShapeFn(TopKShapeFn);
1454
// This is the same as `TopK`, but takes `k` as an input rather than an attr.
// TopKShapeFn reads `k` from input 1 because this op has two inputs.
REGISTER_OP("TopKV2")
    .Input("input: T")
    .Input("k: int32")
    .Output("values: T")
    .Output("indices: int32")
    .Attr("sorted: bool = true")
    .Attr("T: realnumbertype")
    .SetShapeFn(TopKShapeFn);
1464
// ApproxTopK: approximate top-k along `reduction_dimension`.
// The shape function (ApproxTopKShape) reads the attrs k,
// reduction_dimension, recall_target, reduction_input_size_override and
// aggregate_to_topk. `is_max_k` does not affect the output shape.
REGISTER_OP("ApproxTopK")
    .Input("input: T")
    .Output("values: T")
    .Output("indices: int32")
    .Attr("k: int >= 0")
    .Attr("reduction_dimension: int = -1")
    .Attr("recall_target: float = 0.95")
    .Attr("is_max_k: bool = true")
    .Attr("reduction_input_size_override: int = -1")
    .Attr("aggregate_to_topk: bool = true")
    .Attr("T: {half, bfloat16, float}")
    .SetShapeFn(ApproxTopKShape);
1477
1478 // --------------------------------------------------------------------------
1479
1480 REGISTER_OP("NthElement")
1481 .Input("input: T")
1482 .Input("n: int32")
1483 .Output("values: T")
1484 .Attr("reverse: bool = false")
1485 .Attr("T: realnumbertype")
__anon84cbdd652102(InferenceContext* c) 1486 .SetShapeFn([](InferenceContext* c) {
1487 ShapeHandle input;
1488 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
1489
1490 // Get the n value from input tensor, and make sure which is a scalar.
1491 DimensionHandle n_dim;
1492 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &n_dim));
1493
1494 // The last dimension of input tensor must be greater than N.
1495 DimensionHandle last_dim = c->Dim(input, -1);
1496 if (c->ValueKnown(last_dim) && c->ValueKnown(n_dim) &&
1497 c->Value(last_dim) <= c->Value(n_dim)) {
1498 return errors::InvalidArgument(
1499 "Input must have last dimension > n = ", c->Value(n_dim),
1500 " but is ", c->Value(last_dim));
1501 }
1502
1503 // Reduce last_dim for output tensor
1504 ShapeHandle s;
1505 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
1506 c->set_output(0, s);
1507 return OkStatus();
1508 });
1509
1510 // --------------------------------------------------------------------------
1511
// FractionalMaxPool: max pooling with non-integer `pooling_ratio` (a list of
// at least 4 floats). Besides the pooled output it returns the row and column
// pooling sequences (int64). All three output shapes come from
// FractionalPoolShapeFn, defined at the top of this file.
REGISTER_OP("FractionalMaxPool")
    .Input("value: T")
    .Output("output: T")
    .Output("row_pooling_sequence: int64")
    .Output("col_pooling_sequence: int64")
    .Attr("pooling_ratio: list(float) >=4")
    .Attr("pseudo_random: bool = false")
    .Attr("overlapping: bool = false")
    .Attr("deterministic: bool = false")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("T: {float, double, int32, int64}")
    .SetShapeFn(FractionalPoolShapeFn);
1525
// FractionalMaxPoolGrad: gradient of FractionalMaxPool. The output has the
// shape of `orig_input`, which must be rank 4. The other inputs are not
// shape-checked here.
REGISTER_OP("FractionalMaxPoolGrad")
    .Input("orig_input: T")
    .Input("orig_output: T")
    .Input("out_backprop: T")
    .Input("row_pooling_sequence: int64")
    .Input("col_pooling_sequence: int64")
    .Output("output: T")
    .Attr("overlapping: bool = false")
    .Attr("T: {float, double, int32, int64}")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRank(c, 4);
    });
1538
1539 // --------------------------------------------------------------------------
1540
// FractionalAvgPool: average-pooling counterpart of FractionalMaxPool, with
// the same attributes and outputs. It also uses FractionalPoolShapeFn.
REGISTER_OP("FractionalAvgPool")
    .Input("value: T")
    .Output("output: T")
    .Output("row_pooling_sequence: int64")
    .Output("col_pooling_sequence: int64")
    .Attr("pooling_ratio: list(float) >=4")
    .Attr("pseudo_random: bool = false")
    .Attr("overlapping: bool = false")
    .Attr("deterministic: bool = false")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("T: {float, double, int32, int64}")
    .SetShapeFn(FractionalPoolShapeFn);
1554
1555 REGISTER_OP("FractionalAvgPoolGrad")
1556 .Input("orig_input_tensor_shape: int64")
1557 .Input("out_backprop: T")
1558 .Input("row_pooling_sequence: int64")
1559 .Input("col_pooling_sequence: int64")
1560 .Output("output: T")
1561 .Attr("overlapping: bool = false")
1562 .Attr("T: {float, double, int32, int64}")
__anon84cbdd652302(InferenceContext* c) 1563 .SetShapeFn([](InferenceContext* c) {
1564 if (c->input_tensor(0) != nullptr) {
1565 ShapeHandle out;
1566 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1567 c->set_output(0, out);
1568 } else {
1569 c->set_output(0, c->UnknownShapeOfRank(4));
1570 }
1571 return OkStatus();
1572 });
1573
// QuantizedAvgPool: average pooling on a quantized tensor, carrying the float
// min/max range alongside it. Shapes are computed by QuantizedAvgPoolShape.
REGISTER_OP("QuantizedAvgPool")
    .Input("input: T")
    .Input("min_input: float")
    .Input("max_input: float")
    .Output("output: T")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("T: quantizedtype")
    .Attr("ksize: list(int)")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .SetShapeFn(shape_inference::QuantizedAvgPoolShape);
1586
1587 REGISTER_OP("QuantizedBiasAdd")
1588 .Input("input: T1")
1589 .Input("bias: T2")
1590 .Input("min_input: float")
1591 .Input("max_input: float")
1592 .Input("min_bias: float")
1593 .Input("max_bias: float")
1594 .Output("output: out_type")
1595 .Output("min_out: float")
1596 .Output("max_out: float")
1597 .Attr("T1: quantizedtype")
1598 .Attr("T2: quantizedtype")
1599 .Attr("out_type: quantizedtype")
__anon84cbdd652402(InferenceContext* c) 1600 .SetShapeFn([](InferenceContext* c) {
1601 TF_RETURN_IF_ERROR(shape_inference::BiasAddShape(c));
1602 ShapeHandle unused;
1603 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1604 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1605 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1606 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1607 c->set_output(1, c->Scalar());
1608 c->set_output(2, c->Scalar());
1609 return OkStatus();
1610 });
1611
// QuantizedConv2D: 2-D convolution on quantized input/filter with float range
// inputs. `out_type` defaults to qint32 and `dilations` to [1, 1, 1, 1].
// Shapes are computed by QuantizedConv2DShape.
REGISTER_OP("QuantizedConv2D")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::QuantizedConv2DShape);
1629
// QuantizedMaxPool: max pooling on a quantized tensor. `output` follows
// MaxPoolShape. The min/max inputs must be scalars, and the min/max outputs
// are scalars.
REGISTER_OP("QuantizedMaxPool")
    .Input("input: T")
    .Input("min_input: float")
    .Input("max_input: float")
    .Output("output: T")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("T: quantizedtype")
    .Attr("ksize: list(int)")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
      ShapeHandle unused;
      // min_input / max_input must be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return OkStatus();
    });
1650
// QuantizedRelu: `activations` keeps the shape of `features`. The min/max
// inputs must be scalars, and the min/max outputs are scalars.
REGISTER_OP("QuantizedRelu")
    .Input("features: Tinput")
    .Input("min_features: float")
    .Input("max_features: float")
    .Output("activations: out_type")
    .Output("min_activations: float")
    .Output("max_activations: float")
    .Attr("Tinput: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return OkStatus();
    });
1669
// QuantizedRelu6: same shape rules as QuantizedRelu. `activations` keeps the
// shape of `features`; the min/max inputs and outputs are scalars.
REGISTER_OP("QuantizedRelu6")
    .Input("features: Tinput")
    .Input("min_features: float")
    .Input("max_features: float")
    .Output("activations: out_type")
    .Output("min_activations: float")
    .Output("max_activations: float")
    .Attr("Tinput: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return OkStatus();
    });
1688
1689 REGISTER_OP("QuantizedReluX")
1690 .Input("features: Tinput")
1691 .Input("max_value: float")
1692 .Input("min_features: float")
1693 .Input("max_features: float")
1694 .Output("activations: out_type")
1695 .Output("min_activations: float")
1696 .Output("max_activations: float")
1697 .Attr("Tinput: quantizedtype")
1698 .Attr("out_type: quantizedtype = DT_QUINT8")
__anon84cbdd652802(InferenceContext* c) 1699 .SetShapeFn([](InferenceContext* c) {
1700 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1701 ShapeHandle unused;
1702 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1703 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1704 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1705 c->set_output(1, c->Scalar());
1706 c->set_output(2, c->Scalar());
1707 return OkStatus();
1708 });
1709
1710 REGISTER_OP("QuantizedBatchNormWithGlobalNormalization")
1711 .Input("t: Tinput")
1712 .Input("t_min: float")
1713 .Input("t_max: float")
1714 .Input("m: Tinput")
1715 .Input("m_min: float")
1716 .Input("m_max: float")
1717 .Input("v: Tinput")
1718 .Input("v_min: float")
1719 .Input("v_max: float")
1720 .Input("beta: Tinput")
1721 .Input("beta_min: float")
1722 .Input("beta_max: float")
1723 .Input("gamma: Tinput")
1724 .Input("gamma_min: float")
1725 .Input("gamma_max: float")
1726 .Output("result: out_type")
1727 .Output("result_min: float")
1728 .Output("result_max: float")
1729 .Attr("Tinput: quantizedtype")
1730 .Attr("out_type: quantizedtype")
1731 .Attr("variance_epsilon: float")
1732 .Attr("scale_after_normalization: bool")
__anon84cbdd652902(InferenceContext* c) 1733 .SetShapeFn([](InferenceContext* c) {
1734 ShapeHandle input;
1735 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
1736
1737 DimensionHandle last_dim = c->Dim(input, 3);
1738 for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma
1739 ShapeHandle vec;
1740 TF_RETURN_IF_ERROR(c->WithRank(c->input(i * 3), 1, &vec));
1741 TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
1742 }
1743
1744 ShapeHandle out;
1745 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
1746 c->set_output(0, out);
1747 c->set_output(1, c->Scalar());
1748 c->set_output(2, c->Scalar());
1749
1750 return OkStatus();
1751 });
1752
1753 #ifdef INTEL_MKL
// _MklDepthwiseConv2dNative: MKL variant of DepthwiseConv2dNative.
// - Extra uint8 `mkl_*` inputs/outputs accompany the data tensors.
// - Output shapes come from DepthwiseConv2DNativeShapeWithExplicitPadding.
REGISTER_OP("_MklDepthwiseConv2dNative")
    .Input("input: T")
    .Input("filter: T")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShapeWithExplicitPadding);
1771
// _MklConv2D: MKL variant of Conv2D (see the op's Doc string).
// - Extra uint8 `mkl_*` inputs/outputs accompany the data tensors.
// - Output shapes come from Conv2DShapeWithExplicitPadding.
REGISTER_OP("_MklConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
    .Doc(R"doc(
MKL version of Conv2D operator. Uses MKL DNN APIs to perform 2D convolution.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
1796
// Registers _MklNativeConv2D, the eager-mode MKL Conv2D op (per its Doc text).
// Unlike _MklConv2D it has no uint8 mkl_* inputs/outputs and no
// `filter_output`; it takes only `input` and `filter` and produces `output`.
// Shape inference uses Conv2DShapeWithExplicitPadding.
1797 REGISTER_OP("_MklNativeConv2D")
1798 .Input("input: T")
1799 .Input("filter: T")
1800 .Output("output: T")
1801 .Attr("T: {bfloat16, float}")
1802 .Attr("strides: list(int)")
1803 .Attr("use_cudnn_on_gpu: bool = true")
1804 .Attr("is_filter_const: bool = false")
1805 .Attr(GetPaddingAttrStringWithExplicit())
1806 .Attr(GetExplicitPaddingsAttrString())
1807 .Attr(GetConvnetDataFormatAttrString())
1808 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1809 .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
1810 .Doc(R"doc(
1811 MKL version of Conv2D operator for Eager mode. Uses MKL DNN APIs to perform 2D convolution.
1812
1813 NOTE Do not invoke this operator directly in Python. Eager Op rewrite is
1814 expected to invoke these operators.
1815 )doc");
1816
// Registers __MklDummyConv2DWithBias, a placeholder node used while fusing
// Conv2D + BiasAdd (see its Doc text). It declares the three data inputs of
// the fused op (input, filter, bias) and a single output. Output shape comes
// from Conv2DShapeWithExplicitPadding; `bias` does not appear in the
// registration beyond its Input declaration.
1817 REGISTER_OP("__MklDummyConv2DWithBias")
1818 .Input("input: T")
1819 .Input("filter: T")
1820 .Input("bias: T")
1821 .Output("output: T")
1822 .Attr("T: {bfloat16, float}")
1823 .Attr("strides: list(int)")
1824 .Attr("use_cudnn_on_gpu: bool = true")
1825 .Attr("is_filter_const: bool = false")
1826 .Attr(GetPaddingAttrStringWithExplicit())
1827 .Attr(GetExplicitPaddingsAttrString())
1828 .Attr(GetConvnetDataFormatAttrString())
1829 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1830 .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
1831 .Doc(R"doc(
1832 Dummy node that enables fusing Conv2D and BiasAdd operator for MKL. This node
1833 does not perform anything. It is just created as an intermediate output of
1834 merging Conv2D and BiasAdd.
1835
1836 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1837 expected to invoke these operators.
1838 )doc");
1839
// Registers _MklConv2DWithBias, the graph-mode MKL fused Conv2D + BiasAdd op.
// Each of the three data inputs (input, filter, bias) has a uint8 mkl_*
// companion input; `output` and `filter_output` have uint8 mkl_* companion
// outputs. Shape inference uses Conv2DShapeWithExplicitPadding.
1840 REGISTER_OP("_MklConv2DWithBias")
1841 .Input("input: T")
1842 .Input("filter: T")
1843 .Input("bias: T")
1844 .Input("mkl_input: uint8")
1845 .Input("mkl_filter: uint8")
1846 .Input("mkl_bias: uint8")
1847 .Output("output: T")
1848 .Output("filter_output: T")
1849 .Output("mkl_output: uint8")
1850 .Output("mkl_filter_output: uint8")
1851 .Attr("T: {bfloat16, float}")
1852 .Attr("strides: list(int)")
1853 .Attr("use_cudnn_on_gpu: bool = true")
1854 .Attr("is_filter_const: bool = false")
1855 .Attr(GetPaddingAttrStringWithExplicit())
1856 .Attr(GetExplicitPaddingsAttrString())
1857 .Attr(GetConvnetDataFormatAttrString())
1858 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1859 .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
1860 .Doc(R"doc(
1861 MKL version of Conv2D and BiasAdd operator. Uses MKL DNN APIs to perform
1862 2D convolution and add Bias to the output of convolution.
1863
1864 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1865 expected to invoke these operators.
1866 )doc");
1867
// Registers __MklDummyPadWithConv2D, a placeholder node used while fusing
// Pad + Conv2D (see its Doc text). `paddings` has its own index type
// Tpaddings (int32 or int64, default int32). Padding mode uses
// GetPaddingAttrString() (no explicit-padding variant here).
// NOTE(review): the shape fn is plain shape_inference::Conv2DShape; it is not
// visible here whether the `paddings` input is reflected in the inferred
// output shape — confirm.
1868 REGISTER_OP("__MklDummyPadWithConv2D")
1869 .Input("input: T")
1870 .Input("filter: T")
1871 .Input("paddings: Tpaddings")
1872 .Output("output: T")
1873 .Attr("T: {bfloat16, float}")
1874 .Attr("strides: list(int)")
1875 .Attr("use_cudnn_on_gpu: bool = true")
1876 .Attr(GetPaddingAttrString())
1877 .Attr(GetConvnetDataFormatAttrString())
1878 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1879 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1880 .SetShapeFn(shape_inference::Conv2DShape)
1881 .Doc(R"doc(
1882 Dummy node that enables fusing Pad and Conv2D operator for MKL. This node
1883 does not perform anything. It is just created as an intermediate output of
1884 merging Pad and Conv2D.
1885
1886 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1887 expected to invoke these operators.
1888 )doc");
1889
// Registers _MklPadWithConv2D, the graph-mode MKL fused Pad + Conv2D op.
// Each data input (input, filter, paddings) has a uint8 mkl_* companion
// input; `output` and `filter_output` have uint8 mkl_* companion outputs.
// NOTE(review): as with __MklDummyPadWithConv2D, the shape fn is plain
// shape_inference::Conv2DShape; confirm whether `paddings` affects the
// inferred output shape.
1890 REGISTER_OP("_MklPadWithConv2D")
1891 .Input("input: T")
1892 .Input("filter: T")
1893 .Input("paddings: Tpaddings")
1894 .Input("mkl_input: uint8")
1895 .Input("mkl_filter: uint8")
1896 .Input("mkl_paddings: uint8")
1897 .Output("output: T")
1898 .Output("filter_output: T")
1899 .Output("mkl_output: uint8")
1900 .Output("mkl_filter_output: uint8")
1901 .Attr("T: {bfloat16, float}")
1902 .Attr("strides: list(int)")
1903 .Attr("use_cudnn_on_gpu: bool = true")
1904 .Attr(GetPaddingAttrString())
1905 .Attr(GetConvnetDataFormatAttrString())
1906 .Attr("is_filter_const: bool = false")
1907 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1908 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1909 .SetShapeFn(shape_inference::Conv2DShape)
1910 .Doc(R"doc(
1911 MKL version of Pad and Conv2D operator. Uses MKL DNN APIs to perform
1912 Pad and 2D convolution to the output of convolution.
1913
1914 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1915 expected to invoke these operators.
1916 )doc");
1917
1918 REGISTER_OP("_MklConv2DBackpropFilter")
1919 .Input("input: T")
1920 .Input("filter_sizes: int32")
1921 .Input("out_backprop: T")
1922 .Input("mkl_input: uint8")
1923 .Input("mkl_filter_size: uint8")
1924 .Input("mkl_out_backprop: uint8")
1925 .Output("output: T")
1926 .Output("mkl_output: uint8")
1927 .Attr("T: {bfloat16, float}")
1928 .Attr("strides: list(int)")
1929 .Attr("use_cudnn_on_gpu: bool = true")
1930 .Attr(GetPaddingAttrString())
1931 .Attr(GetConvnetDataFormatAttrString())
1932 .Attr(GetExplicitPaddingsAttrString())
1933 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon84cbdd652a02(InferenceContext* c) 1934 .SetShapeFn([](InferenceContext* c) {
1935 ShapeHandle s;
1936 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
1937 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1938 c->set_output(0, s);
1939 return Status::OK();
1940 })
1941 .Doc(R"doc(
1942 MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the
1943 gradients of convolution with respect to the filter.
1944
1945 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1946 expected to invoke these operators.
1947 )doc");
1948
1949 REGISTER_OP("_MklNativeConv2DBackpropFilter")
1950 .Input("input: T")
1951 .Input("filter_sizes: int32")
1952 .Input("out_backprop: T")
1953 .Output("output: T")
1954 .Attr("T: {bfloat16, float}")
1955 .Attr("strides: list(int)")
1956 .Attr("use_cudnn_on_gpu: bool = true")
1957 .Attr(GetPaddingAttrStringWithExplicit())
1958 .Attr(GetExplicitPaddingsAttrString())
1959 .Attr(GetConvnetDataFormatAttrString())
1960 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon84cbdd652b02(InferenceContext* c) 1961 .SetShapeFn([](InferenceContext* c) {
1962 ShapeHandle s;
1963 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
1964 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1965 c->set_output(0, s);
1966 return Status::OK();
1967 })
1968 .Doc(R"doc(
1969 MKL version of Conv2DBackpropFilter for Eager mode. Uses MKL DNN APIs
1970 to compute the gradients of convolution with respect to the filter.
1971
1972 NOTE Do not invoke this operator directly in Python. Eager Op rewrite pass is
1973 expected to invoke these operators.
1974 )doc");
1975
1976 REGISTER_OP("__MklDummyConv2DBackpropFilterWithBias")
1977 .Input("input: T")
1978 .Input("filter_sizes: int32")
1979 .Input("out_backprop: T")
1980 .Output("output: T")
1981 .Output("bias_grad: T")
1982 .Attr("T: {bfloat16, float}")
1983 .Attr("strides: list(int)")
1984 .Attr("use_cudnn_on_gpu: bool = true")
1985 .Attr(GetPaddingAttrString())
1986 .Attr(GetConvnetDataFormatAttrString())
1987 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon84cbdd652c02(InferenceContext* c) 1988 .SetShapeFn([](InferenceContext* c) {
1989 ShapeHandle input_shape;
1990 // Fetch the data_format attribute, which may not exist.
1991 string data_format;
1992 Status s = c->GetAttr("data_format", &data_format);
1993
1994 if (s.ok() && data_format == "NCHW") {
1995 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1996 c->set_output(1, c->Vector(c->Dim(input_shape, -3)));
1997 } else {
1998 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1999 c->set_output(1, c->Vector(c->Dim(input_shape, -1)));
2000 }
2001 ShapeHandle sh;
2002 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh));
2003 TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh));
2004 c->set_output(0, sh);
2005 return Status::OK();
2006 })
2007 .Doc(R"doc(
2008 Dummy node that enables fusing Conv2DBackpropFilter and BiasAddGrad operator
2009 for MKL. This node does not perform anything. It is just created as an
2010 intermediate output of merging Conv2DBackpropFilter and BiasAddGrad.
2011
2012 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2013 expected to invoke these operators.
2014 )doc");
2015
// Registers _MklConv2DBackpropFilterWithBias, the graph-mode MKL fused
// filter-gradient + bias-gradient op. It produces `output` and `bias_grad`
// (type T), each with a uint8 mkl_* companion output. Shape inference is
// delegated to shape_inference::Conv2DBackpropFilterWithBiasShape.
2016 REGISTER_OP("_MklConv2DBackpropFilterWithBias")
2017 .Input("input: T")
2018 .Input("filter_sizes: int32")
2019 .Input("out_backprop: T")
2020 .Input("mkl_input: uint8")
2021 .Input("mkl_filter_size: uint8")
2022 .Input("mkl_out_backprop: uint8")
2023 .Output("output: T")
2024 .Output("bias_grad: T")
2025 .Output("mkl_output: uint8")
2026 .Output("mkl_bias_grad: uint8")
2027 .Attr("T: {bfloat16, float}")
2028 .Attr("strides: list(int)")
2029 .Attr("use_cudnn_on_gpu: bool = true")
2030 .Attr(GetPaddingAttrString())
2031 .Attr(GetConvnetDataFormatAttrString())
2032 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2033 .SetShapeFn(shape_inference::Conv2DBackpropFilterWithBiasShape)
2034 .Doc(R"doc(
2035 MKL version of Conv2DBackpropFilterWithBias. Uses MKL DNN APIs to compute the
2036 gradients of convolution with respect to the filter.
2037
2038 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2039 expected to invoke these operators.
2040 )doc");
2041
// _MklConv2DWithBiasBackpropBias is only registered when INTEL_MKL_ML_ONLY is
// defined. It maps `out_backprop` (plus its uint8 mkl_* companion) to the bias
// gradient `output` (plus `mkl_output`).
// NOTE(review): unlike the other ops in this section, no SetShapeFn is
// registered; confirm whether shape inference is expected to work for it.
2042 #ifdef INTEL_MKL_ML_ONLY
2043 REGISTER_OP("_MklConv2DWithBiasBackpropBias")
2044 .Input("out_backprop: T")
2045 .Input("mkl_out_backprop: uint8")
2046 .Output("output: T")
2047 .Output("mkl_output: uint8")
2048 .Attr("T: {half, float, double}")
2049 .Attr("strides: list(int)")
2050 .Attr(GetConvnetDataFormatAttrString())
2051 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2052 .Doc(R"doc(
2053 MKL version of Conv2DBackpropBias. Uses MKL DNN APIs to compute the
2054 gradients of convolution with respect to the bias.
2055
2056 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2057 expected to invoke these operators.
2058 )doc");
2059 #endif
2060
2061 REGISTER_OP("_MklConv2DBackpropInput")
2062 .Input("input_sizes: int32")
2063 .Input("filter: T")
2064 .Input("out_backprop: T")
2065 .Input("mkl_input_sizes: uint8")
2066 .Input("mkl_filter: uint8")
2067 .Input("mkl_out_backprop: uint8")
2068 .Output("output: T")
2069 .Output("mkl_output: uint8")
2070 .Attr("T: {bfloat16, float}")
2071 .Attr("strides: list(int)")
2072 .Attr("use_cudnn_on_gpu: bool = true")
2073 .Attr(GetPaddingAttrString())
2074 .Attr(GetConvnetDataFormatAttrString())
2075 .Attr(GetExplicitPaddingsAttrString())
2076 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon84cbdd652d02(InferenceContext* c) 2077 .SetShapeFn([](InferenceContext* c) {
2078 ShapeHandle s;
2079 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
2080 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
2081 c->set_output(0, s);
2082 return Status::OK();
2083 })
2084 .Doc(R"doc(
2085 MKL version of Convolution2D backward input. Uses MKL DNN APIs to compute the
2086 gradients of convolution with respect to the input.
2087
2088 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2089 expected to invoke these operators.
2090 )doc");
2091
2092 REGISTER_OP("_MklNativeConv2DBackpropInput")
2093 .Input("input_sizes: int32")
2094 .Input("filter: T")
2095 .Input("out_backprop: T")
2096 .Output("output: T")
2097 .Attr("T: {bfloat16, float}")
2098 .Attr("strides: list(int)")
2099 .Attr("use_cudnn_on_gpu: bool = true")
2100 .Attr(GetPaddingAttrStringWithExplicit())
2101 .Attr(GetExplicitPaddingsAttrString())
2102 .Attr(GetConvnetDataFormatAttrString())
2103 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon84cbdd652e02(InferenceContext* c) 2104 .SetShapeFn([](InferenceContext* c) {
2105 ShapeHandle s;
2106 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
2107 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
2108 c->set_output(0, s);
2109 return Status::OK();
2110 })
2111 .Doc(R"doc(
2112 MKL version of Convolution2D backward input for Eager mode. Uses MKL DNN APIs
2113 to compute the gradients of convolution with respect to the input.
2114
2115 NOTE Do not invoke this operator directly in Python. Eager op rewrite is
2116 expected to invoke these operators.
2117 )doc");
2118
// Registers _MklConv3D, the graph-mode MKL 3-D convolution op. `strides` must
// list at least 5 values, and `dilations` defaults to five 1s.
// `input`/`filter` are paired with uint8 mkl_* inputs; `output` and
// `filter_output` are paired with uint8 mkl_* outputs. Shape inference uses
// shape_inference::Conv3DShape.
2119 REGISTER_OP("_MklConv3D")
2120 .Input("input: T")
2121 .Input("filter: T")
2122 .Input("mkl_input: uint8")
2123 .Input("mkl_filter: uint8")
2124 .Output("output: T")
2125 .Output("filter_output: T")
2126 .Output("mkl_output: uint8")
2127 .Output("mkl_filter_output: uint8")
2128 .Attr("T: {bfloat16, float}")
2129 .Attr("strides: list(int) >= 5")
2130 .Attr("is_filter_const: bool = false")
2131 .Attr(GetPaddingAttrString())
2132 .Attr(GetConvnet3dDataFormatAttrString())
2133 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
2134 .SetShapeFn(shape_inference::Conv3DShape)
2135 .Doc(R"doc(
2136 MKL version of Conv3D operator. Uses MKL DNN APIs to perform 3D convolution.
2137
2138 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2139 expected to invoke these operators.
2140 )doc");
2141
2142 REGISTER_OP("_MklConv3DBackpropInputV2")
2143 .Input("input_sizes: Tshape")
2144 .Input("filter: T")
2145 .Input("out_backprop: T")
2146 .Input("mkl_input_sizes: uint8")
2147 .Input("mkl_filter: uint8")
2148 .Input("mkl_out_backprop: uint8")
2149 .Output("output: T")
2150 .Output("mkl_output: uint8")
2151 .Attr("T: {bfloat16, float}")
2152 .Attr("strides: list(int) >= 5")
2153 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
2154 .Attr("Tshape: {int32, int64} = DT_INT32")
2155 .Attr(GetPaddingAttrString())
2156 .Attr(GetConvnet3dDataFormatAttrString())
__anon84cbdd652f02(InferenceContext* c) 2157 .SetShapeFn([](InferenceContext* c) {
2158 ShapeHandle s;
2159 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
2160 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
2161 c->set_output(0, s);
2162 return Status::OK();
2163 })
2164 .Doc(R"doc(
2165 MKL version of Convolution3D backward input. Uses MKL DNN APIs to compute the
2166 gradients of convolution with respect to the input.
2167
2168 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2169 expected to invoke these operators.
2170 )doc");
2171
2172 REGISTER_OP("_MklConv3DBackpropFilterV2")
2173 .Input("input: T")
2174 .Input("filter_sizes: int32")
2175 .Input("out_backprop: T")
2176 .Input("mkl_input: uint8")
2177 .Input("mkl_filter_size: uint8")
2178 .Input("mkl_out_backprop: uint8")
2179 .Output("output: T")
2180 .Output("mkl_output: uint8")
2181 .Attr("T: {bfloat16, float}")
2182 .Attr("strides: list(int)")
2183 .Attr(GetPaddingAttrString())
2184 .Attr(GetConvnet3dDataFormatAttrString())
2185 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
__anon84cbdd653002(InferenceContext* c) 2186 .SetShapeFn([](InferenceContext* c) {
2187 ShapeHandle s;
2188 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
2189 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
2190 c->set_output(0, s);
2191 return Status::OK();
2192 })
2193 .Doc(R"doc(
2194 MKL version of Conv3DBackpropFilter. Uses MKL DNN APIs to compute the
2195 gradients of convolution with respect to the filter.
2196
2197 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2198 expected to invoke these operators.
2199 )doc");
2200
// Registers _MklRelu. `activations` gets the shape of `features` via
// shape_inference::UnchangedShape (input 0 -> output 0). T defaults to float.
2201 REGISTER_OP("_MklRelu")
2202 .Input("features: T")
2203 .Input("mkl_features: uint8")
2204 .Output("activations: T")
2205 .Output("mkl_activations: uint8")
2206 .Attr("T: {float, bfloat16} = DT_FLOAT")
2207 .SetShapeFn(shape_inference::UnchangedShape)
2208 .Doc(R"doc(
2209 MKL version of Relu operator. Uses MKL DNN APIs to implement Relu operator.
2210
2211 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2212 expected to invoke these operators.
2213 )doc");
2214
// Registers _MklReluGrad. Shape inference merges the shapes of `gradients`
// (input 0) and `features` (input 1) into `backprops` via
// shape_inference::MergeBothInputsShapeFn.
2215 REGISTER_OP("_MklReluGrad")
2216 .Input("gradients: T")
2217 .Input("features: T")
2218 .Input("mkl_gradients: uint8")
2219 .Input("mkl_features: uint8")
2220 .Output("backprops: T")
2221 .Output("mkl_backprops: uint8")
2222 .Attr("T: {float, bfloat16} = DT_FLOAT")
2223 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2224 .Doc(R"doc(
2225 MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified
2226 linear gradients for Relu operation.
2227
2228 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2229 expected to invoke these operators.
2230 )doc");
2231
// Registers _MklRelu6. `activations` gets the shape of `features` via
// shape_inference::UnchangedShape.
2232 REGISTER_OP("_MklRelu6")
2233 .Input("features: T")
2234 .Input("mkl_features: uint8")
2235 .Output("activations: T")
2236 .Output("mkl_activations: uint8")
2237 .Attr("T: {float, bfloat16} = DT_FLOAT")
2238 .SetShapeFn(shape_inference::UnchangedShape)
2239 .Doc(R"doc(
2240 MKL version of Relu6 operator. Uses MKL DNN APIs to implement Relu6 operator.
2241
2242 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2243 expected to invoke these operators.
2244 )doc");
2245
// Registers _MklRelu6Grad. `backprops` takes the merged shape of `gradients`
// and `features` (shape_inference::MergeBothInputsShapeFn).
2246 REGISTER_OP("_MklRelu6Grad")
2247 .Input("gradients: T")
2248 .Input("features: T")
2249 .Input("mkl_gradients: uint8")
2250 .Input("mkl_features: uint8")
2251 .Output("backprops: T")
2252 .Output("mkl_backprops: uint8")
2253 .Attr("T: {float, bfloat16} = DT_FLOAT")
2254 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2255 .Doc(R"doc(
2256 MKL version of Relu6Grad operator. Uses MKL DNN APIs to compute rectified
2257 linear gradients for Relu6 operation.
2258
2259 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2260 expected to invoke these operators.
2261 )doc");
2262
// Registers _MklLeakyRelu. The negative-slope attr `alpha` defaults to 0.2;
// output shape equals input shape (shape_inference::UnchangedShape).
2263 REGISTER_OP("_MklLeakyRelu")
2264 .Input("features: T")
2265 .Input("mkl_features: uint8")
2266 .Output("activations: T")
2267 .Output("mkl_activations: uint8")
2268 .Attr("T: {float, bfloat16} = DT_FLOAT")
2269 .Attr("alpha: float = 0.2")
2270 .SetShapeFn(shape_inference::UnchangedShape)
2271 .Doc(R"doc(
2272 MKL version of LeakyRelu operator. Uses MKL DNN APIs to implement
2273 LeakyRelu operator.
2274
2275 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2276 expected to invoke these operators.
2277 )doc");
2278
// Registers _MklLeakyReluGrad. `alpha` (default 0.2) matches the forward op's
// attr; `backprops` takes the merged shape of `gradients` and `features`.
2279 REGISTER_OP("_MklLeakyReluGrad")
2280 .Input("gradients: T")
2281 .Input("features: T")
2282 .Input("mkl_gradients: uint8")
2283 .Input("mkl_features: uint8")
2284 .Output("backprops: T")
2285 .Output("mkl_backprops: uint8")
2286 .Attr("T: {float, bfloat16} = DT_FLOAT")
2287 .Attr("alpha: float = 0.2")
2288 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2289 .Doc(R"doc(
2290 MKL version of LeakyReluGrad operator. Uses MKL DNN APIs to compute rectified
2291 linear gradients for LeakyReluGrad operation.
2292
2293 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2294 expected to invoke these operators.
2295 )doc");
2296
// Registers _MklElu. Output shape equals input shape
// (shape_inference::UnchangedShape).
2297 REGISTER_OP("_MklElu")
2298 .Input("features: T")
2299 .Input("mkl_features: uint8")
2300 .Output("activations: T")
2301 .Output("mkl_activations: uint8")
2302 .Attr("T: {float, bfloat16} = DT_FLOAT")
2303 .SetShapeFn(shape_inference::UnchangedShape)
2304 .Doc(R"doc(
2305 MKL version of Elu operator. Uses MKL DNN APIs to implement Elu operator.
2306 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2307 expected to invoke these operators.
2308 )doc");
2309
// Registers _MklEluGrad. `backprops` takes the merged shape of `gradients`
// and `features` (shape_inference::MergeBothInputsShapeFn).
2310 REGISTER_OP("_MklEluGrad")
2311 .Input("gradients: T")
2312 .Input("features: T")
2313 .Input("mkl_gradients: uint8")
2314 .Input("mkl_features: uint8")
2315 .Output("backprops: T")
2316 .Output("mkl_backprops: uint8")
2317 .Attr("T: {float, bfloat16} = DT_FLOAT")
2318 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2319 .Doc(R"doc(
2320 MKL version of EluGrad operator. Uses MKL DNN APIs to compute Elu
2321 gradients for Elu operation.
2322 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2323 expected to invoke these operators.
2324 )doc");
2325
2326 REGISTER_OP("_MklSoftmax")
2327 .Input("logits: T")
2328 .Input("mkl_logits: uint8")
2329 .Output("softmax: T")
2330 .Output("mkl_softmax: uint8")
2331 .Attr("T: {bfloat16, half, float, double}")
__anon84cbdd653102(InferenceContext* c) 2332 .SetShapeFn([](InferenceContext* c) {
2333 return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
2334 })
2335 .Doc(R"doc(
2336 MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified
2337 linear gradients for Relu operation.
2338 )doc");
2339
// Registers _MklTanh. T is `realnumbertype`, a wider set than the
// {float, bfloat16} used by the other MKL activations above; output shape
// equals input shape (shape_inference::UnchangedShape).
2340 REGISTER_OP("_MklTanh")
2341 .Input("features: T")
2342 .Input("mkl_features: uint8")
2343 .Output("activations: T")
2344 .Output("mkl_activations: uint8")
2345 .Attr("T: realnumbertype")
2346 .SetShapeFn(shape_inference::UnchangedShape)
2347 .Doc(R"doc(
2348 MKL version of Tanh operator. Uses MKL DNN APIs to implement Tanh operator.
2349 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2350 expected to invoke these operators.
2351 )doc");
2352
// Registers _MklTanhGrad. `backprops` takes the merged shape of `gradients`
// and `features` (shape_inference::MergeBothInputsShapeFn).
2353 REGISTER_OP("_MklTanhGrad")
2354 .Input("gradients: T")
2355 .Input("features: T")
2356 .Input("mkl_gradients: uint8")
2357 .Input("mkl_features: uint8")
2358 .Output("backprops: T")
2359 .Output("mkl_backprops: uint8")
2360 .Attr("T: realnumbertype")
2361 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2362 .Doc(R"doc(
2363 MKL version of TanhGrad operator. Uses MKL DNN APIs to compute tanh
2364 gradients for Tanh operation.
2365 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2366 expected to invoke these operators.
2367 )doc");
2368
// Registers _MklMaxPool. Besides `output` it emits a `workspace` tensor whose
// dtype depends on the build: T when INTEL_MKL_ML_ONLY is defined, uint8
// otherwise. Both outputs have uint8 mkl_* companions. Shape inference uses
// shape_inference::MaxPoolShape.
2369 REGISTER_OP("_MklMaxPool")
2370 .Attr("T: {float, half, bfloat16} = DT_FLOAT")
2371 .Attr("ksize: list(int) >= 4")
2372 .Attr("strides: list(int) >= 4")
2373 .Attr(GetPaddingAttrString())
2374 .Attr(GetConvnetDataFormatAttrString())
2375 .Attr(GetExplicitPaddingsAttrString())
2376 .Attr("workspace_enabled: bool = false")
2377 .Input("input: T")
2378 .Input("mkl_input: uint8")
2379 .Output("output: T")
2380 #ifdef INTEL_MKL_ML_ONLY
2381 .Output("workspace: T")
2382 #else
2383 .Output("workspace: uint8")
2384 #endif
2385 .Output("mkl_output: uint8")
2386 .Output("mkl_workspace: uint8")
2387 .SetShapeFn(shape_inference::MaxPoolShape)
2388 .Doc(R"doc(
2389 MKL version of MaxPool operator. Uses MKL DNN APIs to perform max pooling
2390 on the input.
2391
2392 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2393 expected to invoke these operators.
2394 )doc");
2395
// Registers _MklMaxPoolGrad. It consumes the forward op's `workspace`; its
// dtype must match _MklMaxPool's (T under INTEL_MKL_ML_ONLY, else uint8).
// Each of the four data inputs has a uint8 mkl_* companion input. Shape
// inference uses shape_inference::MaxPoolGradShape.
2396 REGISTER_OP("_MklMaxPoolGrad")
2397 .Attr("T: {float, half, bfloat16} = DT_FLOAT")
2398 .Attr("ksize: list(int) >= 4")
2399 .Attr("strides: list(int) >= 4")
2400 .Attr("workspace_enabled: bool = false")
2401 .Attr(GetPaddingAttrString())
2402 .Attr(GetConvnetDataFormatAttrString())
2403 .Attr(GetExplicitPaddingsAttrString())
2404 .Input("orig_input: T")
2405 .Input("orig_output: T")
2406 .Input("grad: T")
2407 #ifdef INTEL_MKL_ML_ONLY
2408 .Input("workspace: T")
2409 #else
2410 .Input("workspace: uint8")
2411 #endif
2412 .Input("mkl_orig_input: uint8")
2413 .Input("mkl_orig_output: uint8")
2414 .Input("mkl_grad: uint8")
2415 .Input("mkl_workspace: uint8")
2416 .Output("output: T")
2417 .Output("mkl_output: uint8")
2418 .SetShapeFn(shape_inference::MaxPoolGradShape)
2419 .Doc(R"doc(
2420 oneDNN version of MaxPoolGrad. Uses oneDNN APIs to compute gradients of
2421 MaxPool operator.
2422
2423 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2424 expected to invoke these operators.
2425 )doc");
2426
// Registers _MklAvgPool (2-D average pooling). `ksize`/`strides` need at least
// 4 entries. Shape inference uses shape_inference::AvgPoolShape.
2427 REGISTER_OP("_MklAvgPool")
2428 .Input("value: T")
2429 .Input("mkl_input: uint8")
2430 .Output("output: T")
2431 .Output("mkl_output: uint8")
2432 .Attr("ksize: list(int) >= 4")
2433 .Attr("strides: list(int) >= 4")
2434 .Attr(GetPaddingAttrString())
2435 .Attr(GetConvnetDataFormatAttrString())
2436 .Attr("T: {float, half, double, bfloat16}")
2437 .SetShapeFn(shape_inference::AvgPoolShape)
2438 .Doc(R"doc(
2439 MKL version of AvgPool operator. Uses MKL DNN APIs to perform average pooling
2440 on the input.
2441
2442 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2443 expected to invoke these operators.
2444 )doc");
2445
// Registers _MklAvgPoolGrad. The forward input's shape arrives as the int32
// tensor `orig_input_shape` (input 0); shape inference uses
// shape_inference::AvgPoolGradShape (presumably deriving the output shape from
// that tensor — TODO confirm).
2446 REGISTER_OP("_MklAvgPoolGrad")
2447 .Input("orig_input_shape: int32")
2448 .Input("grad: T")
2449 .Input("mkl_orig_input: uint8")
2450 .Input("mkl_grad: uint8")
2451 .Output("output: T")
2452 .Output("mkl_output: uint8")
2453 .Attr("ksize: list(int) >= 4")
2454 .Attr("strides: list(int) >= 4")
2455 .Attr(GetPaddingAttrString())
2456 .Attr(GetConvnetDataFormatAttrString())
2457 .Attr("T: {float, half, double, bfloat16}")
2458 .SetShapeFn(shape_inference::AvgPoolGradShape)
2459 .Doc(R"doc(
2460 oneDNN version of AvgPoolGrad operator. Uses oneDNN APIs to compute gradients
2461 of AvgPool function.
2462
2463 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2464 expected to invoke these operators.
2465 )doc");
2466
// Registers _MklAvgPool3D (3-D average pooling). `ksize`/`strides` need at
// least 5 entries, and data_format is the 3-D convnet variant. Shape inference
// uses shape_inference::Pool3DShape.
2467 REGISTER_OP("_MklAvgPool3D")
2468 .Input("value: T")
2469 .Input("mkl_input: uint8")
2470 .Output("output: T")
2471 .Output("mkl_output: uint8")
2472 .Attr("ksize: list(int) >= 5")
2473 .Attr("strides: list(int) >= 5")
2474 .Attr(GetPaddingAttrString())
2475 .Attr(GetConvnet3dDataFormatAttrString())
2476 .Attr("T: {float, half, double, bfloat16}")
2477 .SetShapeFn(shape_inference::Pool3DShape)
2478 .Doc(R"doc(
2479 MKL version of AvgPool3D operator. Uses MKL DNN APIs to perform average pooling
2480 on the input.
2481
2482 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2483 expected to invoke these operators.
2484 )doc");
2485
// Registers _MklAvgPool3DGrad. The forward input's shape arrives as the int32
// tensor `orig_input_shape` (input 0); shape inference uses
// shape_inference::AvgPool3DGradShape.
2486 REGISTER_OP("_MklAvgPool3DGrad")
2487 .Input("orig_input_shape: int32")
2488 .Input("grad: T")
2489 .Input("mkl_orig_input: uint8")
2490 .Input("mkl_grad: uint8")
2491 .Output("output: T")
2492 .Output("mkl_output: uint8")
2493 .Attr("ksize: list(int) >= 5")
2494 .Attr("strides: list(int) >= 5")
2495 .Attr(GetPaddingAttrString())
2496 .Attr(GetConvnet3dDataFormatAttrString())
2497 .Attr("T: {float, half, double, bfloat16}")
2498 .SetShapeFn(shape_inference::AvgPool3DGradShape)
2499 .Doc(R"doc(
2500 oneDNN version of AvgPool3DGrad operator. Uses oneDNN APIs to compute gradients
2501 of AvgPool function.
2502
2503 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2504 expected to invoke these operators.
2505 )doc");
2506
2507 REGISTER_OP("_MklMaxPool3D")
2508 .Input("input: T")
2509 .Input("mkl_input: uint8")
2510 .Output("output: T")
2511 .Output("workspace: uint8")
2512 .Output("mkl_output: uint8")
2513 .Output("mkl_workspace: uint8")
2514 .Attr("ksize: list(int) >= 5")
2515 .Attr("strides: list(int) >= 5")
2516 .Attr(GetPaddingAttrString())
2517 .Attr(GetConvnet3dDataFormatAttrString())
2518 .Attr("T: {half, bfloat16, float}")
2519 .Attr("workspace_enabled: bool = false")
2520 .SetShapeFn(shape_inference::Pool3DShape)
2521 .Doc(R"doc(
2522 MKL version of MaxPool3D operator. Uses MKL DNN APIs to perform average pooling
2523 on the input.
2524
2525 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2526 expected to invoke these operators.
2527 )doc");
2528
// Registers _MklMaxPool3DGrad. The forward tensors `orig_input`/`orig_output`
// use a separate dtype attr TInput, which may differ from the gradient dtype T
// (both default to float). It consumes the forward op's uint8 `workspace`.
// Shape inference uses shape_inference::MaxPool3DGradShape.
2529 REGISTER_OP("_MklMaxPool3DGrad")
2530 .Input("orig_input: TInput")
2531 .Input("orig_output: TInput")
2532 .Input("grad: T")
2533 .Input("workspace: uint8")
2534 .Input("mkl_orig_input: uint8")
2535 .Input("mkl_orig_output: uint8")
2536 .Input("mkl_grad: uint8")
2537 .Input("mkl_workspace: uint8")
2538 .Output("output: T")
2539 .Output("mkl_output: uint8")
2540 .Attr("ksize: list(int) >= 5")
2541 .Attr("strides: list(int) >= 5")
2542 .Attr(GetPaddingAttrString())
2543 .Attr(GetConvnet3dDataFormatAttrString())
2544 .Attr("T: {half, bfloat16, float} = DT_FLOAT")
2545 .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
2546 .Attr("workspace_enabled: bool = false")
2547 .SetShapeFn(shape_inference::MaxPool3DGradShape)
2548 .Doc(R"doc(
2549 oneDNN version of MaxPool3DGrad operator. Uses oneDNN APIs to compute gradients
2550 of MaxPool3D function.
2551
2552 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2553 expected to invoke these operators.
2554 )doc");
2555
2556 REGISTER_OP("_MklLRN")
2557 .Input("input: T")
2558 .Input("mkl_input: uint8")
2559 .Output("output: T")
2560 .Output("workspace: uint8")
2561 .Output("mkl_output: uint8")
2562 .Output("mkl_workspace: uint8")
2563 .Attr("depth_radius: int = 5")
2564 .Attr("bias: float = 1.0")
2565 .Attr("alpha: float = 1.0")
2566 .Attr("beta: float = 0.5")
2567 .Attr("workspace_enabled: bool = false")
2568 .Attr("T: {float, half} = DT_FLOAT")
__anon84cbdd653202(InferenceContext* c) 2569 .SetShapeFn([](InferenceContext* c) {
2570 return UnchangedShapeWithRank(c, 4);
2571 })
2572 .Doc(R"doc(
2573 MKL version of LRN operator. Uses MKL DNN APIs to perform local response
2574 normalization.
2575
2576 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2577 expected to invoke these operators.
2578 )doc");
2579
2580 REGISTER_OP("_MklLRNGrad")
2581 .Input("input_grads: T")
2582 .Input("input_image: T")
2583 .Input("output_image: T")
2584 .Input("workspace: uint8")
2585 .Input("mkl_input_grads: uint8")
2586 .Input("mkl_input_image: uint8")
2587 .Input("mkl_output_image: uint8")
2588 .Input("mkl_workspace: uint8")
2589 .Output("output: T")
2590 .Output("mkl_output: uint8")
2591 .Attr("depth_radius: int = 5")
2592 .Attr("bias: float = 1.0")
2593 .Attr("alpha: float = 1.0")
2594 .Attr("beta: float = 0.5")
2595 .Attr("workspace_enabled: bool = false")
2596 .Attr("T: {float, half} = DT_FLOAT")
__anon84cbdd653302(InferenceContext* c) 2597 .SetShapeFn([](InferenceContext* c) {
2598 ShapeHandle s;
2599 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads
2600 TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image
2601 TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image
2602 c->set_output(0, s);
2603 return Status::OK();
2604 })
2605 .Doc(R"doc(
2606 MKL version of LRNGrad operator. Uses MKL DNN APIs to compute gradient for
2607 local response normalization.
2608
2609 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2610 expected to invoke these operators.
2611 )doc");
2612
// Registers _MklFusedBatchNorm. All five data inputs and outputs share dtype T
// (any numbertype), and each has a uint8 mkl_* companion.
// Here `data_format` is declared as a plain string attr (default 'NHWC'),
// whereas _MklFusedBatchNormV2 uses GetConvnetDataFormatAttrString().
// Shape inference uses shape_inference::FusedBatchNormShape.
2613 REGISTER_OP("_MklFusedBatchNorm")
2614 .Input("x: T")
2615 .Input("scale: T")
2616 .Input("offset: T")
2617 .Input("mean: T")
2618 .Input("variance: T")
2619 .Input("mkl_x: uint8")
2620 .Input("mkl_scale: uint8")
2621 .Input("mkl_offset: uint8")
2622 .Input("mkl_mean: uint8")
2623 .Input("mkl_variance: uint8")
2624 .Output("y: T")
2625 .Output("batch_mean: T")
2626 .Output("batch_variance: T")
2627 .Output("reserve_space_1: T")
2628 .Output("reserve_space_2: T")
2629 .Output("mkl_y: uint8")
2630 .Output("mkl_batch_mean: uint8")
2631 .Output("mkl_batch_variance: uint8")
2632 .Output("mkl_reserve_space_1: uint8")
2633 .Output("mkl_reserve_space_2: uint8")
2634 .Attr("T: numbertype")
2635 .Attr("epsilon: float = 0.0001")
2636 .Attr("data_format: string = 'NHWC'")
2637 .Attr("exponential_avg_factor: float = 1.0")
2638 .Attr("is_training: bool = true")
2639 .SetShapeFn(shape_inference::FusedBatchNormShape)
2640 .Doc(R"doc(
2641 oneDNN version of FusedBatchNorm operator. Uses oneDNN APIs to perform fused
2642 batch normalization.
2643
2644 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2645 expected to invoke these operators.
2646 )doc");
2647
// Registers _MklFusedBatchNormGrad. `reserve_space_1/2` are inputs here and
// share the name of _MklFusedBatchNorm's outputs (presumably those outputs are
// fed back — TODO confirm in the rewrite pass). All tensors use dtype T, and
// each has a uint8 mkl_* companion. Shape inference uses
// shape_inference::FusedBatchNormGradShape.
2648 REGISTER_OP("_MklFusedBatchNormGrad")
2649 .Input("y_backprop: T")
2650 .Input("x: T")
2651 .Input("scale: T")
2652 .Input("reserve_space_1: T")
2653 .Input("reserve_space_2: T")
2654 .Input("mkl_y_backprop: uint8")
2655 .Input("mkl_x: uint8")
2656 .Input("mkl_scale: uint8")
2657 .Input("mkl_reserve_space_1: uint8")
2658 .Input("mkl_reserve_space_2: uint8")
2659 .Output("x_backprop: T")
2660 .Output("scale_backprop: T")
2661 .Output("offset_backprop: T")
2662 .Output("reserve_space_3: T")
2663 .Output("reserve_space_4: T")
2664 .Output("mkl_x_backprop: uint8")
2665 .Output("mkl_scale_backprop: uint8")
2666 .Output("mkl_offset_backprop: uint8")
2667 .Output("mkl_reserve_space_3: uint8")
2668 .Output("mkl_reserve_space_4: uint8")
2669 .Attr("T: numbertype")
2670 .Attr("epsilon: float = 0.0001")
2671 .Attr("data_format: string = 'NHWC'")
2672 .Attr("is_training: bool = true")
2673 .SetShapeFn(shape_inference::FusedBatchNormGradShape)
2674 .Doc(R"doc(
2675 oneDNN version of FusedBatchNormGrad operator. Uses oneDNN APIs to compute
2676 gradients for fused batch normalization.
2677
2678 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2679 expected to invoke these operators.
2680 )doc");
2681
// V2 of _MklFusedBatchNorm, with separate type attrs:
//   - x and y use T (bfloat16 or float);
//   - scale, offset, mean, variance and the statistics/reserve-space outputs
//     use U, which is restricted to float.
// Every data tensor is paired with a uint8 "mkl_*" tensor at indices 5-9.
// Unlike V1, data_format comes from GetConvnetDataFormatAttrString() and no
// .Doc() string is attached. Shape inference reuses FusedBatchNormShape.
REGISTER_OP("_MklFusedBatchNormV2")
    .Input("x: T")
    .Input("scale: U")
    .Input("offset: U")
    .Input("mean: U")
    .Input("variance: U")
    // uint8 companions of inputs 0-4, in the same order.
    .Input("mkl_x: uint8")
    .Input("mkl_scale: uint8")
    .Input("mkl_offset: uint8")
    .Input("mkl_mean: uint8")
    .Input("mkl_variance: uint8")
    .Output("y: T")
    .Output("batch_mean: U")
    .Output("batch_variance: U")
    .Output("reserve_space_1: U")
    .Output("reserve_space_2: U")
    // uint8 companions of outputs 0-4, in the same order.
    .Output("mkl_y: uint8")
    .Output("mkl_batch_mean: uint8")
    .Output("mkl_batch_variance: uint8")
    .Output("mkl_reserve_space_1: uint8")
    .Output("mkl_reserve_space_2: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("exponential_avg_factor: float = 1.0")
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormShape);
2710
// Gradient counterpart of _MklFusedBatchNormV2.
// y_backprop, x and x_backprop use T (bfloat16 or float). The reserve spaces
// and the scale/offset gradients use U (restricted to float). Every data
// tensor is paired with a uint8 "mkl_*" tensor at indices 5-9.
// Shape inference reuses FusedBatchNormGradShape.
REGISTER_OP("_MklFusedBatchNormGradV2")
    .Input("y_backprop: T")
    .Input("x: T")
    // Declared as plain float rather than U. Because U only admits float, the
    // two coincide.
    .Input("scale: float")
    .Input("reserve_space_1: U")
    .Input("reserve_space_2: U")
    // uint8 companions of inputs 0-4, in the same order.
    .Input("mkl_y_backprop: uint8")
    .Input("mkl_x: uint8")
    .Input("mkl_scale: uint8")
    .Input("mkl_reserve_space_1: uint8")
    .Input("mkl_reserve_space_2: uint8")
    .Output("x_backprop: T")
    .Output("scale_backprop: U")
    .Output("offset_backprop: U")
    .Output("reserve_space_3: U")
    .Output("reserve_space_4: U")
    // uint8 companions of outputs 0-4, in the same order.
    .Output("mkl_x_backprop: uint8")
    .Output("mkl_scale_backprop: uint8")
    .Output("mkl_offset_backprop: uint8")
    .Output("mkl_reserve_space_3: uint8")
    .Output("mkl_reserve_space_4: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormGradShape);
2738
// Layout-conversion op: takes a data tensor plus its uint8 MKL companion and
// emits a single T tensor in TensorFlow layout (see the .Doc() text).
// The output shape is not inferred (UnknownShape).
REGISTER_OP("_MklToTf")
    .Input("input: T")
    .Input("mkl_input: uint8")
    .Output("output: T")
    .Attr("T: {half, float, double, bfloat16, qint8, quint8, qint32}")
    .Attr(GetConvnetDataFormat2D3DAttrString())
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
MKL operator to convert a tensor from MKL layout to TensorFlow layout.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
2752
// Format-normalization op inserted before element-wise MKL ops (see the
// .Doc() text). It takes two data inputs, each with a uint8 MKL companion,
// and produces two data outputs, each with a uint8 MKL companion.
// Output shapes are not inferred (UnknownShape).
REGISTER_OP("_MklInputConversion")
    .Input("input_0: T")
    .Input("input_1: T")
    .Input("mkl_input_0: uint8")
    .Input("mkl_input_1: uint8")
    .Output("output_0: T")
    .Output("output_1: T")
    .Output("mkl_output_0: uint8")
    .Output("mkl_output_1: uint8")
    // All datatypes supported by element-wise ops
    .Attr(
        "T: {half, float, bfloat16, double, uint8, int8, uint16, int16, int32, "
        "int64, complex64, complex128}")
    .Attr(GetConvnetDataFormat2D3DAttrString())
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
MKL operator to process the inputs to an elementwise MKL op. Both inputs
need to be either in TF or in MKL format. This op is added before every
element-wise MKL op.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
2776
2777 #endif // INTEL_MKL
2778 REGISTER_OP("QuantizedConv2DAndRequantize")
2779 .Input("input: Tinput")
2780 .Input("filter: Tfilter")
2781 .Input("min_input: float")
2782 .Input("max_input: float")
2783 .Input("min_filter: float")
2784 .Input("max_filter: float")
2785 .Input("min_freezed_output: float")
2786 .Input("max_freezed_output: float")
2787 .Output("output: out_type")
2788 .Output("min_output: float")
2789 .Output("max_output: float")
2790 .Attr("Tinput: quantizedtype")
2791 .Attr("Tfilter: quantizedtype")
2792 .Attr("out_type: quantizedtype = DT_QINT8")
2793 .Attr("strides: list(int)")
2794 .Attr(GetPaddingAttrString())
2795 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2796 .Attr("padding_list: list(int) = []")
__anon84cbdd653402(InferenceContext* c) 2797 .SetShapeFn([](InferenceContext* c) {
2798 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2799 ShapeHandle unused;
2800 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2801 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2802 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2803 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2804 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2805 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2806 c->set_output(1, c->Scalar());
2807 c->set_output(2, c->Scalar());
2808 return OkStatus();
2809 });
2810
2811 // Fusion of Quantized Conv2D and BiasAdd.
2812 REGISTER_OP("QuantizedConv2DWithBias")
2813 .Input("input: Tinput")
2814 .Input("filter: Tfilter")
2815 .Input("bias: float")
2816 .Input("min_input: float")
2817 .Input("max_input: float")
2818 .Input("min_filter: float")
2819 .Input("max_filter: float")
2820 .Output("output: out_type")
2821 .Output("min_output: float")
2822 .Output("max_output: float")
2823 .Attr("Tinput: quantizedtype")
2824 .Attr("Tfilter: quantizedtype")
2825 .Attr("out_type: quantizedtype = DT_QINT32")
2826 .Attr("strides: list(int)")
2827 .Attr(GetPaddingAttrString())
2828 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2829 .Attr("padding_list: list(int) = []")
__anon84cbdd653502(InferenceContext* c) 2830 .SetShapeFn([](InferenceContext* c) {
2831 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2832 ShapeHandle unused, channel;
2833 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2834 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2835 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2836 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2837 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
2838 c->set_output(1, channel);
2839 c->set_output(2, channel);
2840 return OkStatus();
2841 });
2842
2843 REGISTER_OP("QuantizedConv2DWithBiasAndRequantize")
2844 .Input("input: Tinput")
2845 .Input("filter: Tfilter")
2846 .Input("bias: Tbias")
2847 .Input("min_input: float")
2848 .Input("max_input: float")
2849 .Input("min_filter: float")
2850 .Input("max_filter: float")
2851 .Input("min_freezed_output: float")
2852 .Input("max_freezed_output: float")
2853 .Output("output: out_type")
2854 .Output("min_output: float")
2855 .Output("max_output: float")
2856 .Attr("Tinput: quantizedtype")
2857 .Attr("Tfilter: quantizedtype")
2858 .Attr("Tbias: {float, qint32}")
2859 .Attr("out_type: quantizedtype = DT_QINT8")
2860 .Attr("strides: list(int)")
2861 .Attr(GetPaddingAttrString())
2862 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2863 .Attr("padding_list: list(int) = []")
__anon84cbdd653602(InferenceContext* c) 2864 .SetShapeFn([](InferenceContext* c) {
2865 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2866 ShapeHandle unused, channel;
2867 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2868 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2869 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2870 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2871 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
2872 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2873 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
2874 c->set_output(1, c->Scalar());
2875 c->set_output(2, c->Scalar());
2876 return OkStatus();
2877 });
2878
2879 // Fusion of Quantized Conv2D and Relu.
2880 REGISTER_OP("QuantizedConv2DAndRelu")
2881 .Input("input: Tinput")
2882 .Input("filter: Tfilter")
2883 .Input("min_input: float")
2884 .Input("max_input: float")
2885 .Input("min_filter: float")
2886 .Input("max_filter: float")
2887 .Output("output: out_type")
2888 .Output("min_output: float")
2889 .Output("max_output: float")
2890 .Attr("Tinput: quantizedtype")
2891 .Attr("Tfilter: quantizedtype")
2892 .Attr("out_type: quantizedtype = DT_QINT32")
2893 .Attr("strides: list(int)")
2894 .Attr(GetPaddingAttrString())
2895 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2896 .Attr("padding_list: list(int) = []")
__anon84cbdd653702(InferenceContext* c) 2897 .SetShapeFn([](InferenceContext* c) {
2898 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2899 ShapeHandle unused, channel;
2900 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2901 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2902 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel));
2903 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2904 c->set_output(1, channel);
2905 c->set_output(2, channel);
2906 return OkStatus();
2907 });
2908
2909 REGISTER_OP("QuantizedConv2DAndReluAndRequantize")
2910 .Input("input: Tinput")
2911 .Input("filter: Tfilter")
2912 .Input("min_input: float")
2913 .Input("max_input: float")
2914 .Input("min_filter: float")
2915 .Input("max_filter: float")
2916 .Input("min_freezed_output: float")
2917 .Input("max_freezed_output: float")
2918 .Output("output: out_type")
2919 .Output("min_output: float")
2920 .Output("max_output: float")
2921 .Attr("Tinput: quantizedtype")
2922 .Attr("Tfilter: quantizedtype")
2923 .Attr("out_type: quantizedtype = DT_QUINT8")
2924 .Attr("strides: list(int)")
2925 .Attr(GetPaddingAttrString())
2926 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2927 .Attr("padding_list: list(int) = []")
__anon84cbdd653802(InferenceContext* c) 2928 .SetShapeFn([](InferenceContext* c) {
2929 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2930 ShapeHandle unused, channel;
2931 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2932 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2933 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel));
2934 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2935 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2936 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2937 c->set_output(1, c->Scalar());
2938 c->set_output(2, c->Scalar());
2939 return OkStatus();
2940 });
2941
2942 // Fusion of Quantized Conv2D, BiasAdd and Relu.
2943 REGISTER_OP("QuantizedConv2DWithBiasAndRelu")
2944 .Input("input: Tinput")
2945 .Input("filter: Tfilter")
2946 .Input("bias: float")
2947 .Input("min_input: float")
2948 .Input("max_input: float")
2949 .Input("min_filter: float")
2950 .Input("max_filter: float")
2951 .Output("output: out_type")
2952 .Output("min_output: float")
2953 .Output("max_output: float")
2954 .Attr("Tinput: quantizedtype")
2955 .Attr("Tfilter: quantizedtype")
2956 .Attr("out_type: quantizedtype = DT_QINT32")
2957 .Attr("strides: list(int)")
2958 .Attr(GetPaddingAttrString())
2959 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2960 .Attr("padding_list: list(int) = []")
__anon84cbdd653902(InferenceContext* c) 2961 .SetShapeFn([](InferenceContext* c) {
2962 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2963 ShapeHandle unused, channel;
2964 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2965 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2966 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2967 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2968 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
2969 c->set_output(1, channel);
2970 c->set_output(2, channel);
2971 return OkStatus();
2972 });
2973
2974 // Fusion of Quantized Conv2D, BiasAdd, Relu, and Requantize.
2975 REGISTER_OP("QuantizedConv2DWithBiasAndReluAndRequantize")
2976 .Input("input: Tinput")
2977 .Input("filter: Tfilter")
2978 .Input("bias: Tbias")
2979 .Input("min_input: float")
2980 .Input("max_input: float")
2981 .Input("min_filter: float")
2982 .Input("max_filter: float")
2983 .Input("min_freezed_output: float")
2984 .Input("max_freezed_output: float")
2985 .Output("output: out_type")
2986 .Output("min_output: float")
2987 .Output("max_output: float")
2988 .Attr("Tinput: quantizedtype")
2989 .Attr("Tfilter: quantizedtype")
2990 .Attr("Tbias: {float, qint32}")
2991 .Attr("out_type: quantizedtype = DT_QUINT8")
2992 .Attr("strides: list(int)")
2993 .Attr(GetPaddingAttrString())
2994 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2995 .Attr("padding_list: list(int) = []")
__anon84cbdd653a02(InferenceContext* c) 2996 .SetShapeFn([](InferenceContext* c) {
2997 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2998 ShapeHandle unused, channel;
2999 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3000 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3001 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3002 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3003 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
3004 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3005 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3006 c->set_output(1, c->Scalar());
3007 c->set_output(2, c->Scalar());
3008 return OkStatus();
3009 });
3010
3011 // Fusion of Quantized Conv2D, BiasAdd, Sum, and Relu.
3012 REGISTER_OP("QuantizedConv2DWithBiasSumAndRelu")
3013 .Input("input: Tinput")
3014 .Input("filter: Tfilter")
3015 .Input("bias: float")
3016 .Input("min_input: float")
3017 .Input("max_input: float")
3018 .Input("min_filter: float")
3019 .Input("max_filter: float")
3020 .Input("summand: float")
3021 .Output("output: out_type")
3022 .Output("min_output: float")
3023 .Output("max_output: float")
3024 .Attr("Tinput: quantizedtype")
3025 .Attr("Tfilter: quantizedtype")
3026 .Attr("out_type: quantizedtype = DT_QINT32")
3027 .Attr("strides: list(int)")
3028 .Attr(GetPaddingAttrString())
3029 .Attr("dilations: list(int) = [1, 1, 1, 1]")
3030 .Attr("padding_list: list(int) = []")
__anon84cbdd653b02(InferenceContext* c) 3031 .SetShapeFn([](InferenceContext* c) {
3032 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
3033 ShapeHandle unused, channel;
3034 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3035 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3036 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3037 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3038 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
3039 c->set_output(1, channel);
3040 c->set_output(2, channel);
3041 return OkStatus();
3042 });
3043
3044 REGISTER_OP("QuantizedConv2DWithBiasSumAndReluAndRequantize")
3045 .Input("input: Tinput")
3046 .Input("filter: Tfilter")
3047 .Input("bias: Tbias")
3048 .Input("min_input: float")
3049 .Input("max_input: float")
3050 .Input("min_filter: float")
3051 .Input("max_filter: float")
3052 .Input("min_freezed_output: float")
3053 .Input("max_freezed_output: float")
3054 .Input("summand: Tsummand")
3055 .Input("min_summand: float")
3056 .Input("max_summand: float")
3057 .Output("output: out_type")
3058 .Output("min_output: float")
3059 .Output("max_output: float")
3060 .Attr("Tinput: quantizedtype")
3061 .Attr("Tfilter: quantizedtype")
3062 .Attr("Tbias: {float, qint32}")
3063 .Attr("Tsummand: quantizedtype")
3064 .Attr("out_type: quantizedtype = DT_QUINT8")
3065 .Attr("strides: list(int)")
3066 .Attr(GetPaddingAttrString())
3067 .Attr("dilations: list(int) = [1, 1, 1, 1]")
3068 .Attr("padding_list: list(int) = []")
__anon84cbdd653c02(InferenceContext* c) 3069 .SetShapeFn([](InferenceContext* c) {
3070 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
3071 ShapeHandle unused, channel;
3072 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3073 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3074 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3075 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3076 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
3077 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3078 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3079 c->set_output(1, c->Scalar());
3080 c->set_output(2, c->Scalar());
3081 return OkStatus();
3082 });
3083
3084 REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
3085 .Input("input: Tinput")
3086 .Input("filter: Tfilter")
3087 .Input("bias: Tbias")
3088 .Input("min_input: float")
3089 .Input("max_input: float")
3090 .Input("min_filter: float")
3091 .Input("max_filter: float")
3092 .Input("min_freezed_output: float")
3093 .Input("max_freezed_output: float")
3094 .Input("summand: Tsummand")
3095 .Input("min_summand: float")
3096 .Input("max_summand: float")
3097 .Output("output: out_type")
3098 .Output("min_output: float")
3099 .Output("max_output: float")
3100 .Attr("Tinput: quantizedtype")
3101 .Attr("Tfilter: quantizedtype")
3102 .Attr("Tbias: {float, qint32}")
3103 .Attr("Tsummand: quantizedtype")
3104 .Attr("out_type: quantizedtype = DT_QUINT8")
3105 .Attr("strides: list(int)")
3106 .Attr(GetPaddingAttrString())
3107 .Attr("dilations: list(int) = [1, 1, 1, 1]")
3108 .Attr("padding_list: list(int) = []")
__anon84cbdd653d02(InferenceContext* c) 3109 .SetShapeFn([](InferenceContext* c) {
3110 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
3111 ShapeHandle unused, channel;
3112 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3113 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3114 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3115 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3116 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
3117 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3118 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3119 // Since activations are not requantized per channel, `min_output`
3120 // and `max_output` are scalars.
3121 c->set_output(1, c->Scalar());
3122 c->set_output(2, c->Scalar());
3123 return OkStatus();
3124 });
3125
3126 // Fusion of Quantized MatMul and BiasAdd.
3127 REGISTER_OP("QuantizedMatMulWithBias")
3128 .Input("a: T1")
3129 .Input("b: T2")
3130 .Input("bias: Tbias")
3131 .Input("min_a: float")
3132 .Input("max_a: float")
3133 .Input("min_b: float")
3134 .Input("max_b: float")
3135 .Output("out: Toutput")
3136 .Output("min_out: float")
3137 .Output("max_out: float")
3138 .Attr("T1: quantizedtype")
3139 .Attr("T2: quantizedtype")
3140 .Attr("Tbias: {float, qint32}")
3141 .Attr("Toutput: quantizedtype = DT_QINT32")
3142 .Attr("transpose_a: bool = false")
3143 .Attr("transpose_b: bool = false")
3144 .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
__anon84cbdd653e02(InferenceContext* c) 3145 .SetShapeFn([](InferenceContext* c) {
3146 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3147 ShapeHandle unused;
3148 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3149 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3150 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3151 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3152 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3153 c->set_output(1, c->Scalar());
3154 c->set_output(2, c->Scalar());
3155 return OkStatus();
3156 });
3157
3158 REGISTER_OP("QuantizedMatMulWithBiasAndRelu")
3159 .Input("a: T1")
3160 .Input("b: T2")
3161 .Input("bias: float")
3162 .Input("min_a: float")
3163 .Input("max_a: float")
3164 .Input("min_b: float")
3165 .Input("max_b: float")
3166 .Output("out: Toutput")
3167 .Output("min_out: float")
3168 .Output("max_out: float")
3169 .Attr("T1: quantizedtype")
3170 .Attr("T2: quantizedtype")
3171 .Attr("Toutput: quantizedtype = DT_QINT32")
3172 .Attr("transpose_a: bool = false")
3173 .Attr("transpose_b: bool = false")
3174 .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
__anon84cbdd653f02(InferenceContext* c) 3175 .SetShapeFn([](InferenceContext* c) {
3176 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3177 ShapeHandle unused;
3178 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3179 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3180 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3181 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3182 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3183 c->set_output(1, c->Scalar());
3184 c->set_output(2, c->Scalar());
3185 return OkStatus();
3186 });
3187
3188 REGISTER_OP("QuantizedMatMulWithBiasAndReluAndRequantize")
3189 .Input("a: T1")
3190 .Input("b: T2")
3191 .Input("bias: Tbias")
3192 .Input("min_a: float")
3193 .Input("max_a: float")
3194 .Input("min_b: float")
3195 .Input("max_b: float")
3196 .Input("min_freezed_output: float")
3197 .Input("max_freezed_output: float")
3198 .Output("out: Toutput")
3199 .Output("min_out: float")
3200 .Output("max_out: float")
3201 .Attr("T1: quantizedtype")
3202 .Attr("T2: quantizedtype")
3203 .Attr("Tbias: {float, qint32}")
3204 .Attr("Toutput: quantizedtype = DT_QUINT8")
3205 .Attr("transpose_a: bool = false")
3206 .Attr("transpose_b: bool = false")
3207 .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
__anon84cbdd654002(InferenceContext* c) 3208 .SetShapeFn([](InferenceContext* c) {
3209 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3210 ShapeHandle unused;
3211 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3212 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3213 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3214 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3215 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3216 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3217 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3218 c->set_output(1, c->Scalar());
3219 c->set_output(2, c->Scalar());
3220 return OkStatus();
3221 });
3222
3223 REGISTER_OP("QuantizedMatMulWithBiasAndDequantize")
3224 .Input("a: T1")
3225 .Input("b: T2")
3226 .Input("bias: Tbias")
3227 .Input("min_a: float")
3228 .Input("max_a: float")
3229 .Input("min_b: float")
3230 .Input("max_b: float")
3231 .Input("min_freezed_output: float")
3232 .Input("max_freezed_output: float")
3233 .Output("out: Toutput")
3234 .Attr("T1: quantizedtype")
3235 .Attr("T2: quantizedtype")
3236 .Attr("Tbias: {float, qint32}")
3237 .Attr("Toutput: {float}")
3238 .Attr("transpose_a: bool = false")
3239 .Attr("transpose_b: bool = false")
3240 .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
__anon84cbdd654102(InferenceContext* c) 3241 .SetShapeFn([](InferenceContext* c) {
3242 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3243 ShapeHandle unused;
3244 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3245 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3246 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3247 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3248 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3249 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3250 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3251
3252 return OkStatus();
3253 });
3254
3255 REGISTER_OP("QuantizedMatMulWithBiasAndRequantize")
3256 .Input("a: T1")
3257 .Input("b: T2")
3258 .Input("bias: Tbias")
3259 .Input("min_a: float")
3260 .Input("max_a: float")
3261 .Input("min_b: float")
3262 .Input("max_b: float")
3263 .Input("min_freezed_output: float")
3264 .Input("max_freezed_output: float")
3265 .Output("out: Toutput")
3266 .Output("min_out: float")
3267 .Output("max_out: float")
3268 .Attr("T1: quantizedtype")
3269 .Attr("T2: quantizedtype")
3270 .Attr("Tbias: {float, qint32}")
3271 .Attr("Toutput: quantizedtype = DT_QUINT8")
3272 .Attr("transpose_a: bool = false")
3273 .Attr("transpose_b: bool = false")
3274 .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
__anon84cbdd654202(InferenceContext* c) 3275 .SetShapeFn([](InferenceContext* c) {
3276 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3277 ShapeHandle unused;
3278 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3279 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3280 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3281 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3282 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3283 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3284 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3285 c->set_output(1, c->Scalar());
3286 c->set_output(2, c->Scalar());
3287 return OkStatus();
3288 });
3289
3290 REGISTER_OP("QuantizedConv2DPerChannel")
3291 .Input("input: Tinput")
3292 .Input("filter: Tfilter")
3293 .Input("min_input: float")
3294 .Input("max_input: float")
3295 .Input("min_filter: float")
3296 .Input("max_filter: float")
3297 .Output("output: out_type")
3298 .Output("min_output: float")
3299 .Output("max_output: float")
3300 .Attr("Tinput: quantizedtype")
3301 .Attr("Tfilter: quantizedtype")
3302 .Attr("out_type: quantizedtype = DT_QINT32")
3303 .Attr("strides: list(int)")
3304 .Attr(GetPaddingAttrString())
3305 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon84cbdd654302(InferenceContext* c) 3306 .SetShapeFn([](InferenceContext* c) {
3307 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
3308 ShapeHandle unused, channel;
3309 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
3310 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3311 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel));
3312 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3313 c->set_output(1, channel);
3314 c->set_output(2, channel);
3315 return OkStatus();
3316 });
3317
// Quantized depthwise 2-D convolution. Shape inference delegates entirely to
// shape_inference::DepthwiseConv2DNativeShape.
// NOTE(review): unlike the QuantizedConv2D* registrations in this file, no
// rank checks are applied to min_input/max_input/min_filter/max_filter here.
// min_output/max_output shapes are presumably left unknown by that shape
// function; confirm before relying on them. This op also has no
// padding_list attr, unlike its *WithBiasAndRelu* siblings.
REGISTER_OP("QuantizedDepthwiseConv2D")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
3335
// Quantized depthwise 2-D convolution fused with a float bias (input 2).
// Shape inference delegates entirely to
// shape_inference::DepthwiseConv2DNativeShape.
// NOTE(review): bias and the range inputs are not rank-checked here, and
// min_output/max_output shapes are presumably left unknown; confirm.
REGISTER_OP("QuantizedDepthwiseConv2DWithBias")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: float")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
3354
// Quantized depthwise 2-D convolution fused with a float bias and Relu.
// Shape inference delegates entirely to
// shape_inference::DepthwiseConv2DNativeShape.
// NOTE(review): bias and the range inputs are not rank-checked here, and
// min_output/max_output shapes are presumably left unknown; confirm.
REGISTER_OP("QuantizedDepthwiseConv2DWithBiasAndRelu")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: float")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
3374
// Quantized depthwise 2-D convolution + BiasAdd + Relu followed by
// requantization using the min/max_freezed_output range (inputs 7-8).
// bias may be float or qint32 (Tbias). Shape inference delegates entirely to
// shape_inference::DepthwiseConv2DNativeShape.
// NOTE(review): the range inputs are not rank-checked here, unlike
// QuantizedConv2DWithBiasAndReluAndRequantize; confirm whether that is
// intended.
REGISTER_OP("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: Tbias")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
3397
3398 REGISTER_OP("IsotonicRegression")
3399 .Input("input: T")
3400 .Output("output: output_dtype")
3401 .Output("segments: int32")
3402 .Attr("T: realnumbertype")
3403 .Attr("output_dtype: {half, bfloat16, float, double} = DT_FLOAT")
__anon84cbdd654402(::tensorflow::shape_inference::InferenceContext* context) 3404 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* context) {
3405 context->set_output(0, context->input(0));
3406 context->set_output(1, context->input(0));
3407 return OkStatus();
3408 });
3409
3410 } // namespace tensorflow
3411