xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/complex_support.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <complex>
#include <cstdint>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
22 
23 namespace tflite {
24 namespace ops {
25 namespace builtin {
26 namespace complex {
27 
// Position of the sole (complex-valued) input tensor in the node's inputs.
constexpr int kInputTensor = 0;
// Position of the sole (real-valued) output tensor in the node's outputs.
constexpr int kOutputTensor = 0;
30 
Prepare(TfLiteContext * context,TfLiteNode * node)31 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
32   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
33   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
34 
35   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
36 
37   TF_LITE_ENSURE(context, input->type == kTfLiteComplex64 ||
38                               input->type == kTfLiteComplex128);
39 
40   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
41 
42   if (input->type == kTfLiteComplex64) {
43     TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
44   } else {
45     TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat64);
46   }
47 
48   TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
49   return context->ResizeTensor(context, output, output_shape);
50 }
51 
52 template <typename T, typename ExtractF>
ExtractData(const TfLiteTensor * input,ExtractF extract_func,TfLiteTensor * output)53 void ExtractData(const TfLiteTensor* input, ExtractF extract_func,
54                  TfLiteTensor* output) {
55   const std::complex<T>* input_data = GetTensorData<std::complex<T>>(input);
56   T* output_data = GetTensorData<T>(output);
57   const int input_size = NumElements(input);
58   for (int i = 0; i < input_size; ++i) {
59     *output_data++ = extract_func(*input_data++);
60   }
61 }
62 
EvalReal(TfLiteContext * context,TfLiteNode * node)63 TfLiteStatus EvalReal(TfLiteContext* context, TfLiteNode* node) {
64   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
65 
66   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
67 
68   switch (input->type) {
69     case kTfLiteComplex64: {
70       ExtractData<float>(
71           input,
72           static_cast<float (*)(const std::complex<float>&)>(std::real<float>),
73           output);
74       break;
75     }
76     case kTfLiteComplex128: {
77       ExtractData<double>(input,
78                           static_cast<double (*)(const std::complex<double>&)>(
79                               std::real<double>),
80                           output);
81       break;
82     }
83     default: {
84       TF_LITE_KERNEL_LOG(context,
85                          "Unsupported input type, Real op only supports "
86                          "complex input, but got: ",
87                          TfLiteTypeGetName(input->type));
88       return kTfLiteError;
89     }
90   }
91 
92   return kTfLiteOk;
93 }
94 
EvalImag(TfLiteContext * context,TfLiteNode * node)95 TfLiteStatus EvalImag(TfLiteContext* context, TfLiteNode* node) {
96   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
97 
98   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
99 
100   switch (input->type) {
101     case kTfLiteComplex64: {
102       ExtractData<float>(
103           input,
104           static_cast<float (*)(const std::complex<float>&)>(std::imag<float>),
105           output);
106       break;
107     }
108     case kTfLiteComplex128: {
109       ExtractData<double>(input,
110                           static_cast<double (*)(const std::complex<double>&)>(
111                               std::imag<double>),
112                           output);
113       break;
114     }
115     default: {
116       TF_LITE_KERNEL_LOG(context,
117                          "Unsupported input type, Imag op only supports "
118                          "complex input, but got: ",
119                          TfLiteTypeGetName(input->type));
120       return kTfLiteError;
121     }
122   }
123 
124   return kTfLiteOk;
125 }
126 
EvalAbs(TfLiteContext * context,TfLiteNode * node)127 TfLiteStatus EvalAbs(TfLiteContext* context, TfLiteNode* node) {
128   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
129   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
130 
131   switch (input->type) {
132     case kTfLiteComplex64: {
133       ExtractData<float>(
134           input,
135           static_cast<float (*)(const std::complex<float>&)>(std::abs<float>),
136           output);
137       break;
138     }
139     case kTfLiteComplex128: {
140       ExtractData<double>(input,
141                           static_cast<double (*)(const std::complex<double>&)>(
142                               std::abs<double>),
143                           output);
144       break;
145     }
146     default: {
147       TF_LITE_KERNEL_LOG(context,
148                          "Unsupported input type, ComplexAbs op only supports "
149                          "complex input, but got: ",
150                          TfLiteTypeGetName(input->type));
151       return kTfLiteError;
152     }
153   }
154 
155   return kTfLiteOk;
156 }
157 
158 }  // namespace complex
159 
Register_REAL()160 TfLiteRegistration* Register_REAL() {
161   static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
162                                  complex::Prepare, complex::EvalReal};
163   return &r;
164 }
165 
Register_IMAG()166 TfLiteRegistration* Register_IMAG() {
167   static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
168                                  complex::Prepare, complex::EvalImag};
169   return &r;
170 }
171 
Register_COMPLEX_ABS()172 TfLiteRegistration* Register_COMPLEX_ABS() {
173   static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
174                                  complex::Prepare, complex::EvalAbs};
175   return &r;
176 }
177 
178 }  // namespace builtin
179 }  // namespace ops
180 }  // namespace tflite
181