xref: /aosp_15_r20/external/tensorflow/tensorflow/cc/gradients/grad_helper.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/gradients/grad_helper.h"
17 
18 #include "tensorflow/cc/ops/array_ops.h"
19 #include "tensorflow/cc/ops/data_flow_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 
22 namespace tensorflow {
23 
24 using tensorflow::ops::Add;
25 using tensorflow::ops::Const;
26 using tensorflow::ops::DynamicStitch;
27 using tensorflow::ops::Mod;
28 using tensorflow::ops::OnesLike;
29 using tensorflow::ops::Range;
30 using tensorflow::ops::Size;
31 
ReducedShapeHelper(const Scope & scope,const Output & input_shape,const Output & reduction_axes)32 Output ReducedShapeHelper(const Scope& scope, const Output& input_shape,
33                           const Output& reduction_axes) {
34   auto zero = Const(scope, 0);
35   auto one = Const(scope, 1);
36 
37   // Running example in comments
38   // input_shape = [2, 3, 5, 7]
39   // axes = [1, 2]
40   // The result (a shape after a reduction with keep_dims=True)
41   // [2, 1, 1, 7]
42   //
43   // We can treat each entry in axes as an index into input_shape that
44   // should be replaced by 1.
45   // We use DynamicStitch to do this.
46 
47   // input_rank = 4
48   auto input_rank = Size(scope, input_shape);
49 
50   // Normalize any negative indices in the reduction_axes to positive
51   // values.
52   auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank);
53 
54   // This [0..input_rank) range of integers is used in DynamicStitch to
55   // first copy input_shape to the result.
56   // input_rank_range = [0, 1, 2, 3]
57   auto input_rank_range = Range(scope, zero, input_rank, one);
58 
59   // A 1-filled tensor with the same shape as axes. DynamicStitch will
60   // merge these 1s (using axes for indices) to the correct
61   // position in the result.
62   // axes_ones = [1, 1]
63   auto axes_ones = OnesLike(scope, axes);
64 
65   // using DynamicStitch:
66   // indices = { input_rank_range, axes }
67   //         = { [0, 1, 2, 3], [1, 2] }
68   // data = { input_shape, axes_ones }
69   //      = { [2, 3, 5, 7], [1, 1] }
70   // The input_rank_range entry in indices first replicates the
71   // input_shape to the result.
72   // The axes entry in indices then moves a 1 to each of its entries,
73   // resulting in
74   // [2, 1, 1, 7]
75   std::vector<Output> indices = {input_rank_range, axes};
76   std::vector<Output> data = {input_shape, axes_ones};
77   return DynamicStitch(scope, indices, data);
78 }
79 
80 }  // namespace tensorflow
81