1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_ACCELERATION_TEST_UTIL_INTERNAL_H_
16 #define TENSORFLOW_LITE_KERNELS_ACCELERATION_TEST_UTIL_INTERNAL_H_
17
#include <algorithm>
#include <atomic>
#include <functional>
#include <iterator>
#include <optional>
#include <string>
#include <utility>
#include <vector>
26
27 namespace tflite {
28
// Parses the acceleration configuration text in `config` and hands every
// entry to `consumer`.
//
// Comments and empty lines are skipped. Each remaining line is split into its
// key and its value, and is recognized as either an allowlist or a denylist
// entry. The consumer receives the key, the value and the list flag, and is
// responsible for inserting the data into the target collection.
void ReadAccelerationConfig(
    const char* config,
    const std::function<void(std::string, std::string, bool)>& consumer);
36
// One entry of an acceleration test configuration: a test id pattern
// (`test_id_rex`), the configuration value of type T attached to it and a
// flag telling whether the entry belongs to the denylist.
//
// All accessors are const, so entries can be inspected through const
// references (for example while iterating over a const container) without
// being copied.
template <typename T>
class ConfigurationEntry {
 public:
  // Stores the pattern, the configuration value and the denylist flag.
  // `test_config` is a by-value sink parameter and is moved into the entry.
  ConfigurationEntry(const std::string& test_id_rex, T test_config,
                     bool is_denylist)
      : test_id_rex_(test_id_rex),
        test_config_(std::move(test_config)),
        is_denylist_(is_denylist) {}

  // Reports whether `test_id` matches the pattern of this entry.
  // In this implementation the answer is unconditionally `false`, so
  // `test_id` is not inspected at all.
  bool Matches(const std::string& test_id) const {
    // Always return false on Android because there is no re2 library available.
    return false;
  }

  // Returns true when this entry is a denylist entry.
  bool IsDenylistEntry() const { return is_denylist_; }

  // Returns the configuration value stored in this entry.
  const T& TestConfig() const { return test_config_; }

  // Returns the test id pattern this entry was created with.
  const std::string& TestIdRex() const { return test_id_rex_; }

 private:
  std::string test_id_rex_;
  T test_config_;
  bool is_denylist_;
};
60
61 // Returns the acceleration test configuration for the given test id and
62 // the given acceleration configuration type.
63 // The configuration type is responsible of providing the test configuration
64 // and the parse function to convert configuration lines into configuration
65 // objects.
66 template <typename T>
GetAccelerationTestParam(std::string test_id)67 std::optional<T> GetAccelerationTestParam(std::string test_id) {
68 static std::atomic<std::vector<ConfigurationEntry<T>>*> test_config_ptr;
69
70 if (test_config_ptr.load() == nullptr) {
71 auto config = new std::vector<ConfigurationEntry<T>>();
72
73 auto consumer = [&config](std::string key, std::string value_str,
74 bool is_denylist) mutable {
75 T value = T::ParseConfigurationLine(value_str);
76 config->push_back(ConfigurationEntry<T>(key, value, is_denylist));
77 };
78
79 ReadAccelerationConfig(T::kAccelerationTestConfig, consumer);
80
81 // Even if it has been already set, it would be just replaced with the
82 // same value, just freeing the old value to avoid leaks
83 auto* prev_val = test_config_ptr.exchange(config);
84 delete prev_val;
85 }
86
87 const std::vector<ConfigurationEntry<T>>* test_config =
88 test_config_ptr.load();
89
90 const auto test_config_iter = std::find_if(
91 test_config->begin(), test_config->end(),
92 [&test_id](ConfigurationEntry<T> elem) { return elem.Matches(test_id); });
93 if (test_config_iter != test_config->end() &&
94 !test_config_iter->IsDenylistEntry()) {
95 return std::optional<T>(test_config_iter->TestConfig());
96 } else {
97 return std::optional<T>();
98 }
99 }
100
101 } // namespace tflite
102
103 #endif // TENSORFLOW_LITE_KERNELS_ACCELERATION_TEST_UTIL_INTERNAL_H_
104