xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/core/datatype.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_AGGREGATION_CORE_DATATYPE_H_
18 #define FCP_AGGREGATION_CORE_DATATYPE_H_
19 
20 #include <cstdint>
21 
22 #include "fcp/base/monitoring.h"
23 
24 #ifndef FCP_NANOLIBC
25 #include "absl/strings/string_view.h"
26 #include "fcp/aggregation/core/tensor.pb.h"
27 #endif
28 
29 namespace fcp {
30 namespace aggregation {
31 
32 #ifndef FCP_NANOLIBC
33 // Unless when building with Nanolibc, we can use absl::string_view directly.
34 using string_view = absl::string_view;
35 #else
36 // TODO(team): Minimal implementation of string_view for bare-metal
37 // environment.
38 struct string_view {};
39 #endif
40 
#ifdef FCP_NANOLIBC
// Nanolibc builds do not include tensor.pb.h, so the DataType enum is declared
// here by hand.
// TODO(team): Derive these values from tensor.proto built with Nanopb.
//
// The numeric values are meant to stay in sync with tensorflow::DataType:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto
// This is not strictly required, but keeping them aligned has a number of
// benefits.
enum DataType {
  DT_INVALID = 0,  // Unset / unrecognized type.
  DT_FLOAT = 1,
  DT_DOUBLE = 2,
  DT_INT32 = 3,
  DT_STRING = 7,
  DT_INT64 = 9,

  // TODO(team): Add other types.
  // Only a small subset of tensorflow::DataType should live here: simple
  // integer and floating point types (plus strings).
  //
  // Whenever a DT_ value is added here, also add a matching
  // MATCH_TYPE_AND_DTYPE specialization and a DTYPE_CASE entry in the
  // DTYPE_CASES macro.
};
#endif  // FCP_NANOLIBC
62 
63 namespace internal {
64 
65 // This struct is used to map typename T to DataType and specify other traits
66 // of typename T.
67 template <typename T>
68 struct TypeTraits {
69   constexpr static DataType kDataType = DT_INVALID;
70 };
71 
72 #define MATCH_TYPE_AND_DTYPE(TYPE, DTYPE)        \
73   template <>                                    \
74   struct TypeTraits<TYPE> {                      \
75     constexpr static DataType kDataType = DTYPE; \
76   }
77 
78 // Mapping of native types to DT_ types.
79 // TODO(team): Add other types.
80 MATCH_TYPE_AND_DTYPE(float, DT_FLOAT);
81 MATCH_TYPE_AND_DTYPE(double, DT_DOUBLE);
82 MATCH_TYPE_AND_DTYPE(int32_t, DT_INT32);
83 MATCH_TYPE_AND_DTYPE(int64_t, DT_INT64);
84 MATCH_TYPE_AND_DTYPE(string_view, DT_STRING);
85 
// The macros DTYPE_CASE and DTYPE_CASES translate a runtime Tensor DataType
// into strongly typed calls of code parameterized with the template typename
// TYPE_ARG.
//
// For example, given a function that takes an AggVector<T>:
//   template <typename T>
//   void DoSomething(AggVector<T> agg_vector) { ... }
//
// the following code makes a DoSomething call for a Tensor:
//   DTYPE_CASES(tensor.dtype(), T, DoSomething(tensor.AsAggVector<T>()));
//
// The second parameter names a type alias that the statements in the third
// parameter can use as a template argument. Within each case, TYPE_ARG is an
// alias for the native C++ type matching that DataType value.
//
// The expansions name TypeTraits and string_view by fully qualified names, so
// the macros work from any namespace. This includes namespaces that declare
// their own `internal` namespace, where an unqualified `internal::` would
// fail to resolve. NOTE: DT_INVALID and FCP_LOG are still referenced
// unqualified and must be visible at the expansion site.
//
// STMTS is expanded inside a `switch` case, so a `break` within STMTS exits
// the switch rather than any enclosing loop.

// Forwards its arguments unchanged. Lets a macro argument that contains
// top-level commas be passed through another macro as a single argument.
#define SINGLE_ARG(...) __VA_ARGS__

// Emits one `case` of the DTYPE_CASES switch for the native type TYPE,
// exposing TYPE to STMTS under the alias name TYPE_ARG.
#define DTYPE_CASE(TYPE, TYPE_ARG, STMTS)                           \
  case ::fcp::aggregation::internal::TypeTraits<TYPE>::kDataType: { \
    using TYPE_ARG = TYPE;                                          \
    STMTS;                                                          \
    break;                                                          \
  }

// Switches on TYPE_ENUM (evaluated once) and runs STMTS with TYPE_ARG bound
// to the matching native type. DT_INVALID and unlisted values are logged via
// FCP_LOG(FATAL).
// TODO(team): Add other types.
#define DTYPE_CASES(TYPE_ENUM, TYPE_ARG, STMTS)                              \
  switch (TYPE_ENUM) {                                                       \
    DTYPE_CASE(float, TYPE_ARG, SINGLE_ARG(STMTS))                           \
    DTYPE_CASE(double, TYPE_ARG, SINGLE_ARG(STMTS))                          \
    DTYPE_CASE(int32_t, TYPE_ARG, SINGLE_ARG(STMTS))                         \
    DTYPE_CASE(int64_t, TYPE_ARG, SINGLE_ARG(STMTS))                         \
    DTYPE_CASE(::fcp::aggregation::string_view, TYPE_ARG, SINGLE_ARG(STMTS)) \
    case DT_INVALID:                                                         \
      FCP_LOG(FATAL) << "Invalid type";                                      \
      break;                                                                 \
    default:                                                                 \
      FCP_LOG(FATAL) << "Unknown type";                                      \
      break;                                                                 \
  }
122 
123 }  // namespace internal
124 
125 }  // namespace aggregation
126 }  // namespace fcp
127 
128 #endif  // FCP_AGGREGATION_CORE_DATATYPE_H_
129