xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/full_type_inference_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/full_type_inference_util.h"
17 
18 #include <functional>
19 #include <string>
20 
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/core/framework/full_type.pb.h"
23 #include "tensorflow/core/framework/full_type_util.h"
24 #include "tensorflow/core/framework/op_def_builder.h"
25 #include "tensorflow/core/platform/status.h"
26 #include "tensorflow/core/platform/statusor.h"
27 #include "tensorflow/core/protobuf/error_codes.pb.h"
28 
29 namespace tensorflow {
30 
31 namespace full_type {
32 
33 // Note about error handling:
34 // For inputs which depend on the correctness of the op definition
35 // (i.e. if the op has three inputs, don't set an `i` that exceeds that),
36 // use DCHECK - an incorrect op def is considered a bug.
37 // Whereas for inputs that depend on the correctness of the graph (i.e. user
38 // used the correct ops), use Status - an incorrect graph is considered a user
39 // error.
40 
KeepExisting()41 ForwardTypeInferenceFn KeepExisting() { return nullptr; }
42 
ReplicateInput(int i,int n)43 ForwardTypeInferenceFn ReplicateInput(int i, int n) {
44   return [i, n](const TypeRefVector& input_types, const TypeRefMap& type_vars) {
45     const FullTypeDef& in_type = input_types.at(i).get();
46     FullTypeDef ret_type;
47     if (in_type.type_id() != TFT_UNSET) {
48       ret_type.set_type_id(TFT_PRODUCT);
49       for (int k = 0; k < n; k++) {
50         *(ret_type.add_args()) = in_type;
51       }
52     }
53     return ret_type;
54   };
55 }
56 
Merge()57 ForwardTypeInferenceFn Merge() {
58   return [](const TypeRefVector& input_types,
59             const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
60     DCHECK(!input_types.empty());
61 
62     FullTypeDef merged;
63     for (int i = 0; i < input_types.size(); i++) {
64       const auto& t = input_types[i].get();
65 
66       if (t.type_id() == TFT_UNSET) {
67         continue;
68       }
69 
70       if (IsSubtype(t, merged)) {
71         merged = t;
72         continue;
73       }
74       if (IsSubtype(merged, t)) {
75         continue;
76       }
77 
78       return Status(error::INVALID_ARGUMENT,
79                     absl::StrCat("expected compatible input types, but input ",
80                                  i, ":\n", t.DebugString(),
81                                  " is neither a subtype nor a supertype of the "
82                                  "combined inputs preceding it:\n",
83                                  merged.DebugString()));
84     }
85 
86     FullTypeDef ret_type;
87     if (merged.type_id() != TFT_UNSET) {
88       ret_type.set_type_id(TFT_PRODUCT);
89       *(ret_type.add_args()) = merged;
90     }
91     return ret_type;
92   };
93 }
94 
Encode(FullTypeId t,int i)95 ForwardTypeInferenceFn Encode(FullTypeId t, int i) {
96   return [t, i](const TypeRefVector& input_types,
97                 const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
98     DCHECK(input_types.size() >= i);
99 
100     FullTypeDef ret_type;
101     const FullTypeDef& in_t = input_types[i].get();
102     if (in_t.type_id() == TFT_UNSET) {
103       return ret_type;
104     }
105 
106     ret_type.set_type_id(TFT_PRODUCT);
107 
108     auto* enc_type = ret_type.add_args();
109     enc_type->set_type_id(TFT_ENCODED);
110     *enc_type->add_args() = in_t;
111     enc_type->add_args()->set_type_id(t);
112     return ret_type;
113   };
114 }
115 
Decode(FullTypeId t,int i)116 ForwardTypeInferenceFn Decode(FullTypeId t, int i) {
117   return [t, i](const TypeRefVector& input_types,
118                 const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
119     DCHECK(input_types.size() >= i);
120 
121     const FullTypeDef& in_t = input_types[i].get();
122 
123     const FullTypeId enc_tid = GetArgDefaultUnset(in_t, 1).type_id();
124     if ((enc_tid != TFT_UNSET) && (enc_tid != t)) {
125       return Status(error::INVALID_ARGUMENT,
126                     absl::StrCat("expected encoded type ", t, " for input ", i,
127                                  ", got ", in_t.DebugString()));
128     }
129 
130     FullTypeDef ret_type;
131 
132     const FullTypeDef& out_t = GetArgDefaultUnset(in_t, 0);
133     if (in_t.type_id() == TFT_UNSET) {
134       return ret_type;
135     }
136 
137     ret_type.set_type_id(TFT_PRODUCT);
138     *ret_type.add_args() = out_t;
139     return ret_type;
140   };
141 }
142 
UnaryContainerCreate(FullTypeId t,int element_idx)143 ForwardTypeInferenceFn UnaryContainerCreate(FullTypeId t, int element_idx) {
144   return
145       [t, element_idx](const TypeRefVector& input_types,
146                        const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
147         DCHECK(input_types.size() >= element_idx);
148 
149         FullTypeDef ret_type;
150         ret_type.set_type_id(TFT_PRODUCT);
151         FullTypeDef* arg_t = ret_type.add_args();
152         arg_t->set_type_id(t);
153         *(arg_t->add_args()) = input_types[element_idx].get();
154 
155         return ret_type;
156       };
157 }
158 
UnaryContainerAdd(FullTypeId t,int container_idx,int element_idx,bool homogeneous)159 ForwardTypeInferenceFn UnaryContainerAdd(FullTypeId t, int container_idx,
160                                          int element_idx, bool homogeneous) {
161   return [t, container_idx, element_idx, homogeneous](
162              const TypeRefVector& input_types,
163              const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
164     DCHECK(input_types.size() >= container_idx);
165     DCHECK(input_types.size() >= element_idx);
166 
167     FullTypeDef ret_type;
168     ret_type.set_type_id(TFT_PRODUCT);
169     FullTypeDef* cont_t = ret_type.add_args();
170     cont_t->set_type_id(t);
171 
172     const FullTypeDef& in_cont_t = input_types[container_idx].get();
173     const FullTypeDef& in_el_t = input_types[element_idx].get();
174 
175     if (in_cont_t.type_id() != TFT_UNSET) {
176       if (in_cont_t.type_id() != t) {
177         return Status(
178             error::INVALID_ARGUMENT,
179             absl::StrCat("expected container type ", t, " for input ",
180                          container_idx, ", got ", in_cont_t.DebugString()));
181       }
182       *cont_t = in_cont_t;
183     }
184 
185     VLOG(1) << "ContainerAddUnary: " << cont_t->DebugString() << ", "
186             << in_el_t.DebugString() << ", " << container_idx << "; "
187             << element_idx;
188     for (const auto& tmp : input_types) {
189       VLOG(1) << "  input: " << tmp.get().DebugString();
190     }
191 
192     if (in_el_t.type_id() == TFT_UNSET) {
193       return ret_type;
194     }
195 
196     const FullTypeDef& el_t = GetArgDefaultUnset(*cont_t, 0);
197 
198     if (el_t.type_id() == TFT_UNSET) {
199       cont_t->clear_args();
200       *(cont_t->add_args()) = in_el_t;
201       return ret_type;
202     }
203 
204     if (IsSubtype(in_el_t, el_t)) {
205       // Nothing to do, will not refine the container type based on a single
206       // addition.
207       return ret_type;
208     }
209 
210     if (homogeneous) {
211       return Status(error::INVALID_ARGUMENT,
212                     absl::StrCat("expected a subtype of ", el_t.DebugString(),
213                                  " for input ", element_idx,
214                                  " of a homogeneous container ", t, ", got ",
215                                  in_el_t.DebugString()));
216     } else {
217       // TODO(mdan): Implement if needed.
218       return Status(
219           error::UNIMPLEMENTED,
220           absl::StrCat("need union types for heterogeneous containers.\n"
221                        "A homogeneous container would expect a subtype of ",
222                        el_t.DebugString(), " for input ", element_idx,
223                        ", but got ", in_el_t.DebugString()));
224     }
225   };
226 }
227 
MultiaryUnstack(FullTypeId t,std::function<FullTypeDef (const FullTypeDef &)> unstack)228 ForwardTypeInferenceFn MultiaryUnstack(
229     FullTypeId t, std::function<FullTypeDef(const FullTypeDef&)> unstack) {
230   return [t, unstack](const TypeRefVector& input_types,
231                       const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
232     FullTypeDef ret_type;
233     ret_type.set_type_id(TFT_PRODUCT);
234     FullTypeDef* cont_t = ret_type.add_args();
235     cont_t->set_type_id(t);
236     FullTypeDef* el_t = cont_t->add_args();
237     el_t->set_type_id(TFT_PRODUCT);
238     for (int element_idx = 0; element_idx < input_types.size(); ++element_idx) {
239       *(el_t->add_args()) = unstack(input_types[element_idx].get());
240     }
241     return ret_type;
242   };
243 }
244 
// Unstack mapping for MultiaryUnstack: returns the input type unchanged.
//
// Only TFT_TENSOR and TFT_RAGGED with at most one argument (i.e. no shape
// information) are accepted; TFT_UNSET is also allowed so that weak type
// inference, where a node may lack a full type, keeps working.
// If shapes are added to these types in the future, this function needs to
// be changed to derive the output shape from the input shape and the effect
// of the unstack (e.g. the removed dimension).
FullTypeDef UnstackTensor(const FullTypeDef& t) {
  DCHECK((t.type_id() == TFT_TENSOR) || (t.type_id() == TFT_RAGGED) || (t.type_id() == TFT_UNSET));
  DCHECK_LE(t.args_size(), 1);
  const FullTypeDef& unstacked = t;
  return unstacked;
}
258 
ContainerMap(FullTypeId t,int input_idx,std::function<FullTypeDef (const FullTypeDef &)> map)259 ForwardTypeInferenceFn ContainerMap(
260     FullTypeId t, int input_idx,
261     std::function<FullTypeDef(const FullTypeDef&)> map) {
262   return [t, input_idx, map](
263              const TypeRefVector& input_types,
264              const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
265     DCHECK_GE(input_types.size(), input_idx);
266     const FullTypeDef& in_cont_t = input_types.at(input_idx).get();
267     FullTypeDef ret_type;
268     if (in_cont_t.type_id() == TFT_UNSET) {
269       return ret_type;
270     }
271     if (in_cont_t.type_id() != t) {
272       return Status(error::INVALID_ARGUMENT,
273                     absl::StrCat("expected type ", t, " for input ", input_idx,
274                                  ", got ", in_cont_t.DebugString()));
275     }
276     ret_type.set_type_id(TFT_PRODUCT);
277     FullTypeDef* out_cont_t = ret_type.add_args();
278     out_cont_t->set_type_id(t);
279     const FullTypeDef& in_el_t = GetArgDefaultUnset(in_cont_t, 0);
280     if (in_el_t.type_id() == TFT_UNSET) {
281       return ret_type;
282     }
283     if (in_el_t.type_id() != TFT_PRODUCT) {
284       return Status(error::INVALID_ARGUMENT,
285                     absl::StrCat("expected PRODUCT element type for input ",
286                                  input_idx, ", got ", in_el_t.DebugString()));
287     }
288     FullTypeDef* out_el_t = out_cont_t->add_args();
289     out_el_t->set_type_id(TFT_PRODUCT);
290     for (int k = 0; k < in_el_t.args_size(); k++) {
291       *(out_el_t->add_args()) = map(in_el_t.args(k));
292     }
293     return ret_type;
294   };
295 }
296 
MapCovariant(FullTypeId t,FullTypeId u,int input_idx)297 ForwardTypeInferenceFn MapCovariant(FullTypeId t, FullTypeId u, int input_idx) {
298   return
299       [t, u, input_idx](const TypeRefVector& input_types,
300                         const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
301         DCHECK_GE(input_types.size(), input_idx);
302         const FullTypeDef& in_t = input_types.at(input_idx).get();
303         FullTypeDef ret_type;
304         if (in_t.type_id() == TFT_UNSET) {
305           return ret_type;
306         }
307         if (in_t.type_id() != t) {
308           return Status(error::INVALID_ARGUMENT,
309                         absl::StrCat("expected type ", t, " for input ",
310                                      input_idx, ", got ", in_t.DebugString()));
311         }
312         ret_type.set_type_id(TFT_PRODUCT);
313         FullTypeDef* t = ret_type.add_args();
314         t->set_type_id(u);
315         *t->mutable_args() = in_t.args();
316         return ret_type;
317       };
318 }
319 
// Type mapping for ops that change the batch size: currently the identity.
// Full types carry no shape yet. Once they do, this will need to derive the
// output shape from the input shape and the op's effect on the batch
// dimension, which will require more information than just `t`.
FullTypeDef BatchTensor(const FullTypeDef& t) {
  FullTypeDef batched(t);
  return batched;
}
328 
// Type mapping for ops that shard one tensor into several: currently the
// identity. Full types carry no shape yet. Once they do, this will need to
// derive each shard's shape from the input shape and the sharding, which
// will require more information than just `t`.
FullTypeDef ShardTensor(const FullTypeDef& t) {
  FullTypeDef sharded(t);
  return sharded;
}
337 
338 }  // namespace full_type
339 
340 }  // namespace tensorflow
341