1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20 #include "tensorflow/lite/kernels/internal/tensor.h"
21 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace matrix_set_diag {
28
// Positions of this op's tensors within the node's input / output lists.
constexpr int kInputTensor{0};     // matrix (or stack of matrices) to update
constexpr int kDiagonalTensor{1};  // values written onto the main diagonal
constexpr int kOutputTensor{0};    // result; resized to the input's shape
32
Prepare(TfLiteContext * context,TfLiteNode * node)33 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
34 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
35 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
36 const TfLiteTensor* input;
37 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
38 TfLiteIntArray* input_dims = input->dims;
39 int input_dims_size = input_dims->size;
40 TF_LITE_ENSURE(context, input_dims_size >= 2);
41
42 TfLiteTensor* output;
43 TF_LITE_ENSURE_OK(context,
44 GetOutputSafe(context, node, kOutputTensor, &output));
45
46 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size);
47 for (int i = 0; i < input_dims_size; i++) {
48 output_shape->data[i] = input_dims->data[i];
49 }
50
51 // Resize the output tensor to the same size as the input tensor.
52 output->type = input->type;
53 TF_LITE_ENSURE_OK(context,
54 context->ResizeTensor(context, output, output_shape));
55
56 return kTfLiteOk;
57 }
58
// Writes `batch_size` row-major matrices of `row_size` x `col_size` elements
// from `in` to `out`. Every position whose row index equals its column index
// takes the next value from `diag` (consumed in order across all matrices);
// every other position is copied unchanged from `in`.
// TODO(b/128636574) Move to reference_ops.
template <typename T>
void FillDiagImpl(const T* in, const T* diag, T* out, const int batch_size,
                  const int row_size, const int col_size) {
  const int matrix_elems = row_size * col_size;
  for (int batch = 0; batch < batch_size;
       ++batch, in += matrix_elems, out += matrix_elems) {
    for (int row = 0; row < row_size; ++row) {
      const T* src_row = in + row * col_size;
      T* dst_row = out + row * col_size;
      for (int col = 0; col < col_size; ++col) {
        // Advance through `diag` only when a diagonal slot is written.
        dst_row[col] = (row == col) ? *diag++ : src_row[col];
      }
    }
  }
}
83
84 template <typename T>
FillDiag(const TfLiteTensor * input,const TfLiteTensor * diag,TfLiteTensor * output,const int batch_size,const int row_size,const int col_size)85 void FillDiag(const TfLiteTensor* input, const TfLiteTensor* diag,
86 TfLiteTensor* output, const int batch_size, const int row_size,
87 const int col_size) {
88 FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(diag),
89 GetTensorData<T>(output), batch_size, row_size, col_size);
90 }
91
92 // Fill a tensor with given "diag" values on the diagonal, input values
93 // elsewhere.
FillDiagHelper(const TfLiteTensor * input,const TfLiteTensor * diag,TfLiteTensor * output)94 void FillDiagHelper(const TfLiteTensor* input, const TfLiteTensor* diag,
95 TfLiteTensor* output) {
96 const int num_output_dims = output->dims->size;
97 int batch_size = 1;
98 for (int i = 0; i < num_output_dims - 2; ++i) {
99 batch_size *= output->dims->data[i];
100 }
101
102 const int row_size = output->dims->data[num_output_dims - 2];
103 const int col_size = output->dims->data[num_output_dims - 1];
104 switch (output->type) {
105 case kTfLiteInt64: {
106 return FillDiag<int64_t>(input, diag, output, batch_size, row_size,
107 col_size);
108 }
109 case kTfLiteInt32: {
110 return FillDiag<int32_t>(input, diag, output, batch_size, row_size,
111 col_size);
112 }
113 case kTfLiteInt16: {
114 return FillDiag<int16_t>(input, diag, output, batch_size, row_size,
115 col_size);
116 }
117 case kTfLiteInt8: {
118 return FillDiag<int8_t>(input, diag, output, batch_size, row_size,
119 col_size);
120 }
121 case kTfLiteUInt8: {
122 return FillDiag<uint8_t>(input, diag, output, batch_size, row_size,
123 col_size);
124 }
125 default:
126 return FillDiag<float>(input, diag, output, batch_size, row_size,
127 col_size);
128 }
129 }
130
Eval(TfLiteContext * context,TfLiteNode * node)131 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
132 TfLiteTensor* output;
133 TF_LITE_ENSURE_OK(context,
134 GetOutputSafe(context, node, kOutputTensor, &output));
135 const TfLiteTensor* input;
136 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
137 const TfLiteTensor* diag;
138 TF_LITE_ENSURE_OK(context,
139 GetInputSafe(context, node, kDiagonalTensor, &diag));
140 FillDiagHelper(input, diag, output);
141 return kTfLiteOk;
142 }
143
144 } // namespace matrix_set_diag
145
Register_MATRIX_SET_DIAG()146 TfLiteRegistration* Register_MATRIX_SET_DIAG() {
147 static TfLiteRegistration r = {nullptr, nullptr, matrix_set_diag::Prepare,
148 matrix_set_diag::Eval};
149 return &r;
150 }
151
152 } // namespace builtin
153 } // namespace ops
154 } // namespace tflite
155