/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/tasks/cumsum.h"

#include <map>
#include <string>
#include <utility>

#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
namespace tflite {
namespace gpu {

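// Builds the kernel source that computes a cumulative sum along axis_. Each
// work item owns one run of elements along the summed axis, walks it
// sequentially, and writes the running total at every position.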
void Cumsum::GetCumsumCode(const OperationDef& op_def) {
  AddSrcTensor("src_tensor", op_def.src_tensors[0]);
  AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
  // Per-axis extents of the source tensor, used both as dispatch-grid bounds
  // and as the loop limit for the summed axis.
  std::map<Axis, std::string> task_sizes = {
      {Axis::WIDTH, "args.src_tensor.Width()"},
      {Axis::HEIGHT, "args.src_tensor.Height()"},
      {Axis::DEPTH, "args.src_tensor.Depth()"},
      {Axis::CHANNELS, "args.src_tensor.Slices()"},
      {Axis::BATCH, "args.src_tensor.Batch()"},
  };
  // The summed axis is iterated inside the kernel, so its grid extent
  // collapses to 1 while its true extent becomes the loop limit.
  std::string limit = task_sizes[axis_];
  task_sizes[axis_] = "1";
  std::map<Axis, std::string> index_name = {
      {Axis::WIDTH, "X"},    {Axis::HEIGHT, "Y"},   {Axis::DEPTH, "Z"},
      {Axis::CHANNELS, "S"}, {Axis::BATCH, "B"},
  };
  std::string indexes = "X, Y";
  std::string c;
  c += "MAIN_FUNCTION($0) {\n";
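  // Recover the per-invocation coordinates from the 3D dispatch grid. When
  // present, depth is packed with Y into GLOBAL_ID_1 and batch is packed with
  // X into GLOBAL_ID_0. The coordinate of the summed axis decodes to 0 (its
  // grid extent is 1) and is advanced by the loop further down.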
  if (definition_.dst_tensors[0].HasAxis(Axis::DEPTH)) {
    indexes += ", Z";
    c += "  int linear_id = GLOBAL_ID_1;\n";
    c += "  int Y = linear_id % " + task_sizes[Axis::HEIGHT] + ";\n";
    c += "  int D = linear_id / " + task_sizes[Axis::HEIGHT] + ";\n";
    c += "  if (D >= " + task_sizes[Axis::DEPTH] + ") return;\n";
  } else {
    c += "  int Y = GLOBAL_ID_1;\n";
    c += "  if (Y >= " + task_sizes[Axis::HEIGHT] + ") return;\n";
  }
  indexes += ", S";
  if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
    indexes += ", B";
    c += "  int linear_id = GLOBAL_ID_0;\n";
    c += "  int X = linear_id / " + task_sizes[Axis::BATCH] + ";\n";
    c += "  int B = linear_id % " + task_sizes[Axis::BATCH] + ";\n";
    c += "  if (X >= " + task_sizes[Axis::WIDTH] + ") return;\n";
  } else {
    c += "  int X = GLOBAL_ID_0;\n";
    c += "  if (X >= " + task_sizes[Axis::WIDTH] + ") return;\n";
  }
  c += "  int S = GLOBAL_ID_2;\n";
  c += "  if (S >= " + task_sizes[Axis::CHANNELS] + ") return;\n";
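  // Walk the summed axis, accumulating into res and writing the running
  // total at every step.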
  c += "  args.src_tensor::type res = args.src_tensor::zero_value;\n";
  c += "  for (; " + index_name[axis_] + " < " + limit + "; " +
       index_name[axis_] + "++) {\n";
  c += "    args.src_tensor::type curr = args.src_tensor.Read(" + indexes +
       ");\n";
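  // Channels are packed four per slice, so the prefix sum is threaded through
  // the vector components; res.w carries the total of the previous slice into
  // the next loop iteration.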
  if (axis_ == Axis::CHANNELS) {
    c += "    res.x = res.w + curr.x;\n";
    c += "    res.y = res.x + curr.y;\n";
    c += "    res.z = res.y + curr.z;\n";
    c += "    res.w = res.z + curr.w;\n";
  } else {
    c += "    res += curr;\n";
  }
  c += "    args.dst_tensor.Write(res, " + indexes + ");\n";
  c += "  }\n";
  c += "}\n";
  code_ = c;
}

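// The summed axis maps to a single grid cell because each work item covers
// its whole extent; the remaining axes span the 3D grid as (W*B, H*D, S).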
int3 Cumsum::GetGridSize() const {
  const int width = axis_ == Axis::WIDTH ? 1 : src_[0]->Width();
  const int height = axis_ == Axis::HEIGHT ? 1 : src_[0]->Height();
  const int depth = axis_ == Axis::DEPTH ? 1 : src_[0]->Depth();
  const int batch = axis_ == Axis::BATCH ? 1 : src_[0]->Batch();
  const int slices = axis_ == Axis::CHANNELS ? 1 : src_[0]->Slices();
  const int grid_x = width * batch;
  const int grid_y = height * depth;
  const int grid_z = slices;
  return int3(grid_x, grid_y, grid_z);
}

Cumsum::Cumsum(Cumsum&& operation)
    : GPUOperation(std::move(operation)), axis_(operation.axis_) {}

Cumsum& Cumsum::operator=(Cumsum&& operation) {
  if (this != &operation) {
    axis_ = operation.axis_;
    GPUOperation::operator=(std::move(operation));
  }
  return *this;
}

Cumsum CreateCumsum(const OperationDef& definition,
                    const CumsumAttributes& attr) {
  Cumsum op(definition, attr.axis);
  op.GetCumsumCode(definition);
  return op;
}

}  // namespace gpu
}  // namespace tflite