xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/tasks/cumsum.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/gpu/common/tasks/cumsum.h"
16 
#include <map>
#include <string>
#include <utility>
19 
20 #include "tensorflow/lite/delegates/gpu/common/operations.h"
21 #include "tensorflow/lite/delegates/gpu/common/shape.h"
22 
23 namespace tflite {
24 namespace gpu {
25 
GetCumsumCode(const OperationDef & op_def)26 void Cumsum::GetCumsumCode(const OperationDef& op_def) {
27   AddSrcTensor("src_tensor", op_def.src_tensors[0]);
28   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
29   std::map<Axis, std::string> task_sizes = {
30       {Axis::WIDTH, "args.src_tensor.Width()"},
31       {Axis::HEIGHT, "args.src_tensor.Height()"},
32       {Axis::DEPTH, "args.src_tensor.Depth()"},
33       {Axis::CHANNELS, "args.src_tensor.Slices()"},
34       {Axis::BATCH, "args.src_tensor.Batch()"},
35   };
36   std::string limit = task_sizes[axis_];
37   task_sizes[axis_] = "1";
38   std::map<Axis, std::string> index_name = {
39       {Axis::WIDTH, "X"},    {Axis::HEIGHT, "Y"}, {Axis::DEPTH, "Z"},
40       {Axis::CHANNELS, "S"}, {Axis::BATCH, "B"},
41   };
42   std::string indexes = "X, Y";
43   std::string c;
44   c += "MAIN_FUNCTION($0) {\n";
45   if (definition_.dst_tensors[0].HasAxis(Axis::DEPTH)) {
46     indexes += ", Z";
47     c += "  int linear_id = GLOBAL_ID_1;\n";
48     c += "  int Y = linear_id % " + task_sizes[Axis::HEIGHT] + ";\n";
49     c += "  int D = linear_id / " + task_sizes[Axis::HEIGHT] + ";\n";
50     c += "  if (D >= " + task_sizes[Axis::DEPTH] + ") return;\n";
51   } else {
52     c += "  int Y = GLOBAL_ID_1;\n";
53     c += "  if (Y >= " + task_sizes[Axis::HEIGHT] + ") return;\n";
54   }
55   indexes += ", S";
56   if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
57     indexes += ", B";
58     c += "  int linear_id = GLOBAL_ID_0;\n";
59     c += "  int X = linear_id / " + task_sizes[Axis::BATCH] + ";\n";
60     c += "  int B = linear_id % " + task_sizes[Axis::BATCH] + ";\n";
61     c += "  if (X >= " + task_sizes[Axis::WIDTH] + ") return;\n";
62   } else {
63     c += "  int X = GLOBAL_ID_0;\n";
64     c += "  if (X >= " + task_sizes[Axis::WIDTH] + ") return;\n";
65   }
66   c += "  int S = GLOBAL_ID_2;\n";
67   c += "  if (S >= " + task_sizes[Axis::CHANNELS] + ") return;\n";
68   c += "  args.src_tensor::type res = args.src_tensor::zero_value;\n";
69   c += "  for (; " + index_name[axis_] + " < " + limit + "; " +
70        index_name[axis_] + "++) {\n";
71   c += "    args.src_tensor::type curr = args.src_tensor.Read(" + indexes +
72        ");\n";
73   if (axis_ == Axis::CHANNELS) {
74     c += "    res.x = res.w + curr.x;\n";
75     c += "    res.y = res.x + curr.y;\n";
76     c += "    res.z = res.y + curr.z;\n";
77     c += "    res.w = res.z + curr.w;\n";
78   } else {
79     c += "    res += curr;\n";
80   }
81   c += "    args.dst_tensor.Write(res, " + indexes + ");\n";
82   c += "  }\n";
83   c += "}\n";
84   code_ = c;
85 }
86 
GetGridSize() const87 int3 Cumsum::GetGridSize() const {
88   const int width = axis_ == Axis::WIDTH ? 1 : src_[0]->Width();
89   const int height = axis_ == Axis::HEIGHT ? 1 : src_[0]->Height();
90   const int depth = axis_ == Axis::DEPTH ? 1 : src_[0]->Depth();
91   const int batch = axis_ == Axis::BATCH ? 1 : src_[0]->Batch();
92   const int slices = axis_ == Axis::CHANNELS ? 1 : src_[0]->Slices();
93   const int grid_x = width * batch;
94   const int grid_y = height * depth;
95   const int grid_z = slices;
96   return int3(grid_x, grid_y, grid_z);
97 }
98 
Cumsum(Cumsum && operation)99 Cumsum::Cumsum(Cumsum&& operation)
100     : GPUOperation(std::move(operation)), axis_(operation.axis_) {}
101 
operator =(Cumsum && operation)102 Cumsum& Cumsum::operator=(Cumsum&& operation) {
103   if (this != &operation) {
104     axis_ = operation.axis_;
105     GPUOperation::operator=(std::move(operation));
106   }
107   return *this;
108 }
109 
CreateCumsum(const OperationDef & definition,const CumsumAttributes & attr)110 Cumsum CreateCumsum(const OperationDef& definition,
111                     const CumsumAttributes& attr) {
112   Cumsum op(definition, attr.axis);
113   op.GetCumsumCode(definition);
114   return op;
115 }
116 
117 }  // namespace gpu
118 }  // namespace tflite
119