xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/full_type_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef CORE_FRAMEWORK_FULL_TYPE_UTIL_H_
17 #define CORE_FRAMEWORK_FULL_TYPE_UTIL_H_
18 
19 #include <functional>
20 #include <string>
21 
22 #include "tensorflow/core/framework/full_type.pb.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/framework/node_def_util.h"
25 #include "tensorflow/core/framework/op_def.pb.h"
26 #include "tensorflow/core/framework/op_def_builder.h"
27 #include "tensorflow/core/platform/statusor.h"
28 
29 namespace tensorflow {
30 
31 namespace full_type {
32 
33 // TODO(mdan): Specific helpers won't get too far. Use a parser instead.
34 // TODO(mdan): Move constructors into a separate file.
35 
36 // Helpers that allow shorthand expression for the more common kinds of type
37 // constructors.
38 // Note: The arity below refers to the number of arguments of parametric types,
39 // not to the number of return values from a particular op.
40 // Note: Type constructors are meant to create static type definitions in the
41 // op definition (i.e. the OpDef proto).
42 
43 // Helper for a no-op type constructor that indicates that the node's type
44 // should be set by external means (typically by the user).
45 OpTypeConstructor NoOp();
46 
47 // Helper for a trivial type constructor that indicates a node has no
48 // outputs (that is, its output type is an empty TFT_PRODUCT).
49 OpTypeConstructor NoOutputs();
50 
51 // Helper for a type constructor of <t>[] (with no parameters).
52 OpTypeConstructor Nullary(FullTypeId t);
53 
54 // Helper for a type constructor of <t>[FT_VAR[<var_name>]].
55 OpTypeConstructor Unary(FullTypeId t, const string& var_name);
56 
57 // Helper for a type constructor of <t>[FT_ANY].
58 OpTypeConstructor UnaryGeneric(FullTypeId t);
59 
60 // Helper for a type constructor of <t>[FT_TENSOR[<dtype>]].
61 OpTypeConstructor UnaryTensorContainer(FullTypeId t, FullTypeId dtype);
62 
63 // Helper for a type constructor of <t>[FT_VAR[<var_name>]].
64 OpTypeConstructor UnaryTensorContainer(FullTypeId t, const string& var_name);
65 
66 // Helper for a type constructor of
67 // <t>[FT_FOR_EACH[
68 //     FT_PRODUCT,
69 //     FT_TENSOR[FT_VAR[<var_name>]],
70 //     FT_VAR[<var_name>]].
71 // Multi-valued type variables will expand the template (see full_type.proto).
72 OpTypeConstructor VariadicTensorContainer(FullTypeId t, const string& var_name);
73 
74 // Type specialization and inference logic. This function narrows the type
75 // specified in an op definition. Such types are usually generic and dependent
76 // on input types. This function resolves the output types based on the input
77 // types specified in a given node def.
78 Status SpecializeType(const AttrSlice& attrs, const OpDef& op_def,
79                       FullTypeDef& target);
80 
81 const FullTypeDef& GetArgDefaultUnset(const FullTypeDef& t, int i);
82 const FullTypeDef& GetArgDefaultAny(const FullTypeDef& t, int i);
83 
84 bool IsEqual(const FullTypeDef& lhs, const FullTypeDef& rhs);
85 
86 bool IsSubtype(const FullTypeDef& lhs, const FullTypeDef& rhs,
87                bool covariant = true);
88 
89 uint64_t Hash(const FullTypeDef& arg);
90 
91 // Determine if the given fulltype is a host memory type.
92 // While it is prefered that Placer (placer.cc and colocation_graph.cc) make
93 // all host memory type placement decisions, any decision made elsewhere
94 // should use this function (e.g. instead of assuming that all variants never
95 // contain host memory types).
IsHostMemoryType(const FullTypeDef & t)96 inline bool IsHostMemoryType(const FullTypeDef& t) {
97   switch (t.type_id()) {
98     case TFT_TENSOR:
99       return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0));
100     case TFT_ARRAY:
101       return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0));
102     case TFT_DATASET:
103       return true;
104     case TFT_MUTEX_LOCK:
105       return true;
106     case TFT_RAGGED:
107       return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0));
108     case TFT_STRING:
109       return true;
110     case TFT_ITERATOR:
111       return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0));
112     case TFT_OPTIONAL:
113       return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0));
114     case TFT_PRODUCT:
115       for (int i = 0; i < t.args_size(); i++) {
116         if (IsHostMemoryType(full_type::GetArgDefaultAny(t, i))) {
117           return true;
118         }
119       }
120       return false;
121     default:
122       return false;
123   }
124 }
125 
126 }  // namespace full_type
127 
128 }  // namespace tensorflow
129 
130 #endif  // CORE_FRAMEWORK_FULL_TYPE_UTIL_H_
131