1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <complex>
17
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/tensor.h"
20 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
21 #include "tensorflow/lite/kernels/kernel_util.h"
22
23 namespace tflite {
24 namespace ops {
25 namespace builtin {
26 namespace complex {
27
// Index of the sole input tensor (a complex64/complex128 tensor).
constexpr int kInputTensor = 0;
// Index of the sole output tensor (the matching real floating-point tensor).
constexpr int kOutputTensor = 0;
30
Prepare(TfLiteContext * context,TfLiteNode * node)31 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
32 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
33 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
34
35 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
36
37 TF_LITE_ENSURE(context, input->type == kTfLiteComplex64 ||
38 input->type == kTfLiteComplex128);
39
40 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
41
42 if (input->type == kTfLiteComplex64) {
43 TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
44 } else {
45 TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat64);
46 }
47
48 TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
49 return context->ResizeTensor(context, output, output_shape);
50 }
51
52 template <typename T, typename ExtractF>
ExtractData(const TfLiteTensor * input,ExtractF extract_func,TfLiteTensor * output)53 void ExtractData(const TfLiteTensor* input, ExtractF extract_func,
54 TfLiteTensor* output) {
55 const std::complex<T>* input_data = GetTensorData<std::complex<T>>(input);
56 T* output_data = GetTensorData<T>(output);
57 const int input_size = NumElements(input);
58 for (int i = 0; i < input_size; ++i) {
59 *output_data++ = extract_func(*input_data++);
60 }
61 }
62
EvalReal(TfLiteContext * context,TfLiteNode * node)63 TfLiteStatus EvalReal(TfLiteContext* context, TfLiteNode* node) {
64 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
65
66 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
67
68 switch (input->type) {
69 case kTfLiteComplex64: {
70 ExtractData<float>(
71 input,
72 static_cast<float (*)(const std::complex<float>&)>(std::real<float>),
73 output);
74 break;
75 }
76 case kTfLiteComplex128: {
77 ExtractData<double>(input,
78 static_cast<double (*)(const std::complex<double>&)>(
79 std::real<double>),
80 output);
81 break;
82 }
83 default: {
84 TF_LITE_KERNEL_LOG(context,
85 "Unsupported input type, Real op only supports "
86 "complex input, but got: ",
87 TfLiteTypeGetName(input->type));
88 return kTfLiteError;
89 }
90 }
91
92 return kTfLiteOk;
93 }
94
EvalImag(TfLiteContext * context,TfLiteNode * node)95 TfLiteStatus EvalImag(TfLiteContext* context, TfLiteNode* node) {
96 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
97
98 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
99
100 switch (input->type) {
101 case kTfLiteComplex64: {
102 ExtractData<float>(
103 input,
104 static_cast<float (*)(const std::complex<float>&)>(std::imag<float>),
105 output);
106 break;
107 }
108 case kTfLiteComplex128: {
109 ExtractData<double>(input,
110 static_cast<double (*)(const std::complex<double>&)>(
111 std::imag<double>),
112 output);
113 break;
114 }
115 default: {
116 TF_LITE_KERNEL_LOG(context,
117 "Unsupported input type, Imag op only supports "
118 "complex input, but got: ",
119 TfLiteTypeGetName(input->type));
120 return kTfLiteError;
121 }
122 }
123
124 return kTfLiteOk;
125 }
126
EvalAbs(TfLiteContext * context,TfLiteNode * node)127 TfLiteStatus EvalAbs(TfLiteContext* context, TfLiteNode* node) {
128 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
129 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
130
131 switch (input->type) {
132 case kTfLiteComplex64: {
133 ExtractData<float>(
134 input,
135 static_cast<float (*)(const std::complex<float>&)>(std::abs<float>),
136 output);
137 break;
138 }
139 case kTfLiteComplex128: {
140 ExtractData<double>(input,
141 static_cast<double (*)(const std::complex<double>&)>(
142 std::abs<double>),
143 output);
144 break;
145 }
146 default: {
147 TF_LITE_KERNEL_LOG(context,
148 "Unsupported input type, ComplexAbs op only supports "
149 "complex input, but got: ",
150 TfLiteTypeGetName(input->type));
151 return kTfLiteError;
152 }
153 }
154
155 return kTfLiteOk;
156 }
157
158 } // namespace complex
159
Register_REAL()160 TfLiteRegistration* Register_REAL() {
161 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
162 complex::Prepare, complex::EvalReal};
163 return &r;
164 }
165
Register_IMAG()166 TfLiteRegistration* Register_IMAG() {
167 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
168 complex::Prepare, complex::EvalImag};
169 return &r;
170 }
171
Register_COMPLEX_ABS()172 TfLiteRegistration* Register_COMPLEX_ABS() {
173 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
174 complex::Prepare, complex::EvalAbs};
175 return &r;
176 }
177
178 } // namespace builtin
179 } // namespace ops
180 } // namespace tflite
181