xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/matrix_set_diag.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20 #include "tensorflow/lite/kernels/internal/tensor.h"
21 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace matrix_set_diag {
28 
// Positions of this op's tensors within the node's input/output lists.
constexpr int kInputTensor = 0;     // Input: the matrix (or batch of matrices).
constexpr int kDiagonalTensor = 1;  // Input: the replacement diagonal values.
constexpr int kOutputTensor = 0;    // Output: the only output tensor.
32 
Prepare(TfLiteContext * context,TfLiteNode * node)33 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
34   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
35   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
36   const TfLiteTensor* input;
37   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
38   TfLiteIntArray* input_dims = input->dims;
39   int input_dims_size = input_dims->size;
40   TF_LITE_ENSURE(context, input_dims_size >= 2);
41 
42   TfLiteTensor* output;
43   TF_LITE_ENSURE_OK(context,
44                     GetOutputSafe(context, node, kOutputTensor, &output));
45 
46   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size);
47   for (int i = 0; i < input_dims_size; i++) {
48     output_shape->data[i] = input_dims->data[i];
49   }
50 
51   // Resize the output tensor to the same size as the input tensor.
52   output->type = input->type;
53   TF_LITE_ENSURE_OK(context,
54                     context->ResizeTensor(context, output, output_shape));
55 
56   return kTfLiteOk;
57 }
58 
// Writes `batch_size` row-major matrices of `row_size` x `col_size` elements
// to `out`. Each element on the main diagonal (row index == column index) is
// taken from `diag`; every other element is copied from the same position in
// `in`.
//
// `diag` is consumed sequentially and its read position is never reset
// between batches, so it must hold batch_size * min(row_size, col_size)
// values. `in` and `out` must each hold batch_size * row_size * col_size
// values.
//
// Offsets are tracked with 64-bit arithmetic and per-row pointers so that
// very large matrices cannot overflow an `int` index.
// TODO(b/128636574) Move to reference_ops.
template <typename T>
void FillDiagImpl(const T* in, const T* diag, T* out, const int batch_size,
                  const int row_size, const int col_size) {
  const int64_t matrix_elems = static_cast<int64_t>(row_size) * col_size;
  int64_t diag_idx = 0;
  for (int b = 0; b < batch_size; ++b) {
    const T* in_row = in;
    T* out_row = out;
    for (int i = 0; i < row_size; ++i) {
      for (int j = 0; j < col_size; ++j) {
        // diag values go on the diagonal, in values elsewhere
        out_row[j] = (i == j) ? diag[diag_idx++] : in_row[j];
      }
      in_row += col_size;
      out_row += col_size;
    }
    // Step to the next matrix in the batch.
    in += matrix_elems;
    out += matrix_elems;
  }
}
83 
84 template <typename T>
FillDiag(const TfLiteTensor * input,const TfLiteTensor * diag,TfLiteTensor * output,const int batch_size,const int row_size,const int col_size)85 void FillDiag(const TfLiteTensor* input, const TfLiteTensor* diag,
86               TfLiteTensor* output, const int batch_size, const int row_size,
87               const int col_size) {
88   FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(diag),
89                   GetTensorData<T>(output), batch_size, row_size, col_size);
90 }
91 
92 // Fill a tensor with given "diag" values on the diagonal, input values
93 // elsewhere.
FillDiagHelper(const TfLiteTensor * input,const TfLiteTensor * diag,TfLiteTensor * output)94 void FillDiagHelper(const TfLiteTensor* input, const TfLiteTensor* diag,
95                     TfLiteTensor* output) {
96   const int num_output_dims = output->dims->size;
97   int batch_size = 1;
98   for (int i = 0; i < num_output_dims - 2; ++i) {
99     batch_size *= output->dims->data[i];
100   }
101 
102   const int row_size = output->dims->data[num_output_dims - 2];
103   const int col_size = output->dims->data[num_output_dims - 1];
104   switch (output->type) {
105     case kTfLiteInt64: {
106       return FillDiag<int64_t>(input, diag, output, batch_size, row_size,
107                                col_size);
108     }
109     case kTfLiteInt32: {
110       return FillDiag<int32_t>(input, diag, output, batch_size, row_size,
111                                col_size);
112     }
113     case kTfLiteInt16: {
114       return FillDiag<int16_t>(input, diag, output, batch_size, row_size,
115                                col_size);
116     }
117     case kTfLiteInt8: {
118       return FillDiag<int8_t>(input, diag, output, batch_size, row_size,
119                               col_size);
120     }
121     case kTfLiteUInt8: {
122       return FillDiag<uint8_t>(input, diag, output, batch_size, row_size,
123                                col_size);
124     }
125     default:
126       return FillDiag<float>(input, diag, output, batch_size, row_size,
127                              col_size);
128   }
129 }
130 
Eval(TfLiteContext * context,TfLiteNode * node)131 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
132   TfLiteTensor* output;
133   TF_LITE_ENSURE_OK(context,
134                     GetOutputSafe(context, node, kOutputTensor, &output));
135   const TfLiteTensor* input;
136   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
137   const TfLiteTensor* diag;
138   TF_LITE_ENSURE_OK(context,
139                     GetInputSafe(context, node, kDiagonalTensor, &diag));
140   FillDiagHelper(input, diag, output);
141   return kTfLiteOk;
142 }
143 
144 }  // namespace matrix_set_diag
145 
Register_MATRIX_SET_DIAG()146 TfLiteRegistration* Register_MATRIX_SET_DIAG() {
147   static TfLiteRegistration r = {nullptr, nullptr, matrix_set_diag::Prepare,
148                                  matrix_set_diag::Eval};
149   return &r;
150 }
151 
152 }  // namespace builtin
153 }  // namespace ops
154 }  // namespace tflite
155