1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/runtime_topk.h"
17
18 #include <algorithm>
19 #include <cstring>
20 #include <limits>
21 #include <memory>
22 #include <numeric>
23 #include <vector>
24
25 #include "absl/base/dynamic_annotations.h"
26
// Selects, for each of `batch_size` rows of `input_size` elements in `values`,
// the `k` largest elements in descending order.
//
// For row `b` and rank `i` in `[0, k)`, the selected element is written to
// `out_values[b * k + i]` and its position within the row to
// `out_indices[b * k + i]`. Elements are ranked by their 32-bit pattern rather
// than by floating-point comparison, which gives the total order
// -NaN < -Inf < -0 < +0 < +Inf < +NaN. Elements with identical bit patterns
// are ordered by ascending index.
//
// Callers are expected to pass `0 <= k <= input_size`. Out-of-range arguments
// are tolerated instead of invoking undefined behavior: non-positive sizes or
// a non-positive `k` make the call a no-op, and when `k > input_size` only the
// first `input_size` slots of each `k`-wide output row are written.
//
// Indices are produced as `int32_t`, so `input_size` must fit in that type.
template <typename T>
static void TopK(int64_t batch_size, int64_t input_size, int64_t k,
                 const T* values, T* out_values, int32_t* out_indices) {
  // The ordering key below reinterprets exactly 32 bits of every element, so
  // any other element width would read the wrong number of bytes.
  static_assert(sizeof(T) == sizeof(uint32_t),
                "TopK only supports 32-bit element types");

  if (batch_size <= 0 || input_size <= 0) {
    return;  // No input to read.
  }

  // 'values' is managed by the JIT code, so msan can't tell they are
  // initialized.
  ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(values,
                                      input_size * batch_size * sizeof(T));

  if (k <= 0) {
    return;  // Nothing was requested.
  }

  // Never ask partial_sort for more elements than a row holds; an iterator
  // past end() would be undefined behavior.
  const int64_t num_selected = std::min(k, input_size);

  // Maps an element to a signed integer whose natural ordering matches the
  // total order documented above. Patterns with the sign bit clear already
  // compare correctly as integers; patterns with the sign bit set are
  // reflected (in modular unsigned arithmetic) so that larger magnitudes
  // become smaller keys.
  const auto to_sort_key = [](T value) -> int32_t {
    uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits));
    const uint32_t key =
        (bits >> 31) != 0
            ? static_cast<uint32_t>(std::numeric_limits<int32_t>::max()) - bits
            : bits;
    return static_cast<int32_t>(key);
  };

  // Scratch permutation of row positions, reused across batches.
  std::vector<int32_t> temp_indices(input_size);
  for (int64_t batch = 0; batch < batch_size; ++batch) {
    std::iota(temp_indices.begin(), temp_indices.end(), 0);

    const T* values_batch = values + batch * input_size;

    const auto selected_end = temp_indices.begin() + num_selected;
    std::partial_sort(temp_indices.begin(), selected_end, temp_indices.end(),
                      [&](int32_t lhs, int32_t rhs) {
                        const int32_t lhs_key = to_sort_key(values_batch[lhs]);
                        const int32_t rhs_key = to_sort_key(values_batch[rhs]);
                        if (lhs_key == rhs_key) {
                          return lhs < rhs;  // Stabilize sorting.
                        }
                        return lhs_key > rhs_key;  // Largest first.
                      });

    // Output rows are always `k` wide, even if fewer slots are filled.
    T* out_values_batch = out_values + batch * k;
    int32_t* out_indices_batch = out_indices + batch * k;
    std::copy(temp_indices.begin(), selected_end, out_indices_batch);
    for (int64_t i = 0; i < num_selected; ++i) {
      out_values_batch[i] = values_batch[temp_indices[i]];
    }
  }
}
70
__xla_cpu_runtime_TopKF32(int64_t batch_size,int64_t input_size,int64_t k,const float * values,float * out_values,int32_t * out_indices)71 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TopKF32(
72 int64_t batch_size, int64_t input_size, int64_t k, const float* values,
73 float* out_values, int32_t* out_indices) {
74 TopK(batch_size, input_size, k, values, out_values, out_indices);
75 }
76