1 /* 2 * Copyright 2022 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_AGGREGATION_CORE_DATATYPE_H_ 18 #define FCP_AGGREGATION_CORE_DATATYPE_H_ 19 20 #include <cstdint> 21 22 #include "fcp/base/monitoring.h" 23 24 #ifndef FCP_NANOLIBC 25 #include "absl/strings/string_view.h" 26 #include "fcp/aggregation/core/tensor.pb.h" 27 #endif 28 29 namespace fcp { 30 namespace aggregation { 31 32 #ifndef FCP_NANOLIBC 33 // Unless when building with Nanolibc, we can use absl::string_view directly. 34 using string_view = absl::string_view; 35 #else 36 // TODO(team): Minimal implementation of string_view for bare-metal 37 // environment. 38 struct string_view {}; 39 #endif 40 41 #ifdef FCP_NANOLIBC 42 // TODO(team): Derive these values from tensor.proto built with Nanopb 43 enum DataType { 44 // The constants below should be kept in sync with tensorflow::Datatype: 45 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto 46 // While not strictly required, that has a number of benefits. 47 DT_INVALID = 0, 48 DT_FLOAT = 1, 49 DT_DOUBLE = 2, 50 DT_INT32 = 3, 51 DT_STRING = 7, 52 DT_INT64 = 9, 53 54 // TODO(team): Add other types. 55 // This should be a small subset of tensorflow::DataType types and include 56 // only simple numeric types and floating point types. 57 // 58 // When a tensor DT_ type is added here, it must also be added to the list of 59 // MATCH_TYPE_AND_DTYPE macros below and to the CASES macro. 60 }; 61 #endif // FCP_NANOLIBC 62 63 namespace internal { 64 65 // This struct is used to map typename T to DataType and specify other traits 66 // of typename T. 67 template <typename T> 68 struct TypeTraits { 69 constexpr static DataType kDataType = DT_INVALID; 70 }; 71 72 #define MATCH_TYPE_AND_DTYPE(TYPE, DTYPE) \ 73 template <> \ 74 struct TypeTraits<TYPE> { \ 75 constexpr static DataType kDataType = DTYPE; \ 76 } 77 78 // Mapping of native types to DT_ types. 79 // TODO(team): Add other types. 80 MATCH_TYPE_AND_DTYPE(float, DT_FLOAT); 81 MATCH_TYPE_AND_DTYPE(double, DT_DOUBLE); 82 MATCH_TYPE_AND_DTYPE(int32_t, DT_INT32); 83 MATCH_TYPE_AND_DTYPE(int64_t, DT_INT64); 84 MATCH_TYPE_AND_DTYPE(string_view, DT_STRING); 85 86 // The macros DTYPE_CASE and DTYPE_CASES are used to translate Tensor DataType 87 // to strongly typed calls of code parameterized with the template typename 88 // TYPE_ARG. 89 // 90 // For example, let's say there is a function that takes an AggVector<T>: 91 // template <typename T> 92 // void DoSomething(AggVector<T> agg_vector) { ... } 93 // 94 // Given a Tensor, the following code can be used to make a DoSomething call: 95 // DTYPE_CASES(tensor.dtype(), T, DoSomething(tensor.AsAggVector<T>())); 96 // 97 // The second parameter specifies the type argument to be used as the template 98 // parameter in the statement in the third argument. 99 100 #define SINGLE_ARG(...) __VA_ARGS__ 101 #define DTYPE_CASE(TYPE, TYPE_ARG, STMTS) \ 102 case internal::TypeTraits<TYPE>::kDataType: { \ 103 typedef TYPE TYPE_ARG; \ 104 STMTS; \ 105 break; \ 106 } 107 108 // TODO(team): Add other types. 109 #define DTYPE_CASES(TYPE_ENUM, TYPE_ARG, STMTS) \ 110 switch (TYPE_ENUM) { \ 111 DTYPE_CASE(float, TYPE_ARG, SINGLE_ARG(STMTS)) \ 112 DTYPE_CASE(double, TYPE_ARG, SINGLE_ARG(STMTS)) \ 113 DTYPE_CASE(int32_t, TYPE_ARG, SINGLE_ARG(STMTS)) \ 114 DTYPE_CASE(int64_t, TYPE_ARG, SINGLE_ARG(STMTS)) \ 115 DTYPE_CASE(string_view, TYPE_ARG, SINGLE_ARG(STMTS)) \ 116 case DT_INVALID: \ 117 FCP_LOG(FATAL) << "Invalid type"; \ 118 break; \ 119 default: \ 120 FCP_LOG(FATAL) << "Unknown type"; \ 121 } 122 123 } // namespace internal 124 125 } // namespace aggregation 126 } // namespace fcp 127 128 #endif // FCP_AGGREGATION_CORE_DATATYPE_H_ 129