1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/example/feature_util.h"
17
18 #include <string>
19
20 #include "absl/strings/string_view.h"
21
22 namespace tensorflow {
23
24 namespace internal {
ExampleFeature(absl::string_view name,Example * example)25 Feature& ExampleFeature(absl::string_view name, Example* example) {
26 return *GetFeature(name, example);
27 }
28
29 } // namespace internal
30
31 template <>
HasFeature(absl::string_view key,const Features & features)32 bool HasFeature<>(absl::string_view key, const Features& features) {
33 return features.feature().contains(internal::ProtoMapKey(key));
34 }
35
36 template <>
HasFeature(absl::string_view key,const Features & features)37 bool HasFeature<protobuf_int64>(absl::string_view key,
38 const Features& features) {
39 auto it = features.feature().find(internal::ProtoMapKey(key));
40 return (it != features.feature().end()) &&
41 (it->second.kind_case() == Feature::KindCase::kInt64List);
42 }
43
44 template <>
HasFeature(absl::string_view key,const Features & features)45 bool HasFeature<float>(absl::string_view key, const Features& features) {
46 auto it = features.feature().find(internal::ProtoMapKey(key));
47 return (it != features.feature().end()) &&
48 (it->second.kind_case() == Feature::KindCase::kFloatList);
49 }
50
51 template <>
HasFeature(absl::string_view key,const Features & features)52 bool HasFeature<std::string>(absl::string_view key, const Features& features) {
53 auto it = features.feature().find(internal::ProtoMapKey(key));
54 return (it != features.feature().end()) &&
55 (it->second.kind_case() == Feature::KindCase::kBytesList);
56 }
57
58 template <>
HasFeature(absl::string_view key,const Features & features)59 bool HasFeature<tstring>(absl::string_view key, const Features& features) {
60 auto it = features.feature().find(internal::ProtoMapKey(key));
61 return (it != features.feature().end()) &&
62 (it->second.kind_case() == Feature::KindCase::kBytesList);
63 }
64
HasFeatureList(absl::string_view key,const SequenceExample & sequence_example)65 bool HasFeatureList(absl::string_view key,
66 const SequenceExample& sequence_example) {
67 return sequence_example.feature_lists().feature_list().contains(
68 internal::ProtoMapKey(key));
69 }
70
71 template <>
GetFeatureValues(const Feature & feature)72 const protobuf::RepeatedField<protobuf_int64>& GetFeatureValues<protobuf_int64>(
73 const Feature& feature) {
74 return feature.int64_list().value();
75 }
76
77 template <>
GetFeatureValues(Feature * feature)78 protobuf::RepeatedField<protobuf_int64>* GetFeatureValues<protobuf_int64>(
79 Feature* feature) {
80 return feature->mutable_int64_list()->mutable_value();
81 }
82
83 template <>
GetFeatureValues(const Feature & feature)84 const protobuf::RepeatedField<float>& GetFeatureValues<float>(
85 const Feature& feature) {
86 return feature.float_list().value();
87 }
88
89 template <>
GetFeatureValues(Feature * feature)90 protobuf::RepeatedField<float>* GetFeatureValues<float>(Feature* feature) {
91 return feature->mutable_float_list()->mutable_value();
92 }
93
94 template <>
GetFeatureValues(const Feature & feature)95 const protobuf::RepeatedPtrField<std::string>& GetFeatureValues<tstring>(
96 const Feature& feature) {
97 return feature.bytes_list().value();
98 }
99
100 template <>
GetFeatureValues(const Feature & feature)101 const protobuf::RepeatedPtrField<std::string>& GetFeatureValues<std::string>(
102 const Feature& feature) {
103 return feature.bytes_list().value();
104 }
105
106 template <>
GetFeatureValues(Feature * feature)107 protobuf::RepeatedPtrField<std::string>* GetFeatureValues<tstring>(
108 Feature* feature) {
109 return feature->mutable_bytes_list()->mutable_value();
110 }
111
112 template <>
GetFeatureValues(Feature * feature)113 protobuf::RepeatedPtrField<std::string>* GetFeatureValues<std::string>(
114 Feature* feature) {
115 return feature->mutable_bytes_list()->mutable_value();
116 }
117
GetFeatureList(absl::string_view key,const SequenceExample & sequence_example)118 const protobuf::RepeatedPtrField<Feature>& GetFeatureList(
119 absl::string_view key, const SequenceExample& sequence_example) {
120 return sequence_example.feature_lists()
121 .feature_list()
122 .at(internal::ProtoMapKey(key))
123 .feature();
124 }
125
GetFeatureList(absl::string_view feature_list_key,SequenceExample * sequence_example)126 protobuf::RepeatedPtrField<Feature>* GetFeatureList(
127 absl::string_view feature_list_key, SequenceExample* sequence_example) {
128 return (*sequence_example->mutable_feature_lists()
129 ->mutable_feature_list())[internal::ProtoMapKey(
130 feature_list_key)]
131 .mutable_feature();
132 }
133
134 template <>
ClearFeatureValues(Feature * feature)135 void ClearFeatureValues<protobuf_int64>(Feature* feature) {
136 feature->mutable_int64_list()->Clear();
137 }
138
139 template <>
ClearFeatureValues(Feature * feature)140 void ClearFeatureValues<float>(Feature* feature) {
141 feature->mutable_float_list()->Clear();
142 }
143
144 template <>
ClearFeatureValues(Feature * feature)145 void ClearFeatureValues<std::string>(Feature* feature) {
146 feature->mutable_bytes_list()->Clear();
147 }
148
149 template <>
ClearFeatureValues(Feature * feature)150 void ClearFeatureValues<tstring>(Feature* feature) {
151 feature->mutable_bytes_list()->Clear();
152 }
153
154 template <>
GetFeatures(Features * proto)155 Features* GetFeatures<Features>(Features* proto) {
156 return proto;
157 }
158
159 template <>
GetFeatures(Example * proto)160 Features* GetFeatures<Example>(Example* proto) {
161 return proto->mutable_features();
162 }
163
164 template <>
GetFeatures(const Features & proto)165 const Features& GetFeatures<Features>(const Features& proto) {
166 return proto;
167 }
168
169 template <>
GetFeatures(const Example & proto)170 const Features& GetFeatures<Example>(const Example& proto) {
171 return proto.features();
172 }
173
// Every GetFeatureValues explicit specialization is defined earlier in this
// file; a definition is also a declaration, so no trailing re-declarations
// are needed here.
204
205 } // namespace tensorflow
206