xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/function_testlib.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_
18 
19 #include <string>
20 
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/function.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/graph.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/lib/gtl/array_slice.h"
27 #include "tensorflow/core/platform/types.h"
28 
29 namespace tensorflow {
30 namespace test {
31 namespace function {
32 
33 // A helper class to make AttrSlice from initializer lists
34 class Attrs {
35  public:
Attrs(const std::initializer_list<std::pair<string,FunctionDefHelper::AttrValueWrapper>> & attrs)36   Attrs(const std::initializer_list<  // NOLINT(runtime/explicit)
37         std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) {
38     for (const auto& aval : attrs) {
39       map_.insert({aval.first, aval.second.proto});
40     }
41   }
42 
Attrs(const std::vector<std::pair<string,FunctionDefHelper::AttrValueWrapper>> & attrs)43   Attrs(
44       const std::vector<std::pair<string, FunctionDefHelper::AttrValueWrapper>>&
45           attrs) {
46     for (const auto& aval : attrs) {
47       map_.insert({aval.first, aval.second.proto});
48     }
49   }
50 
AttrSlice()51   operator AttrSlice() { return AttrSlice(&map_); }  // NOLINT(runtime/explicit)
52 
53  private:
54   AttrValueMap map_;
55 };
56 
57 // Helper to construct a NodeDef.
58 NodeDef NDef(
59     StringPiece name, StringPiece op, gtl::ArraySlice<string> inputs,
60     gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>>
61         attrs = {},
62     const string& device = "");
63 
64 // Helper to construct a GraphDef proto.
65 GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
66               gtl::ArraySlice<FunctionDef> funcs = {});
67 
68 // For testing convenience, we provide a few simple functions that can
69 // be easily executed and tested.
70 
71 // x: T -> x * 2.
72 FunctionDef XTimesTwo();
73 
74 // x: T -> cpu(x * 2) + cpu(x * 3).
75 FunctionDef TwoDeviceTimesFive();
76 
77 // x: T -> cpu(x * 2), gpu(x * 3).
78 FunctionDef TwoDeviceMult();
79 
80 // cpu(x): T, gpu(y): T -> cpu(x * 2), gpu(y * 3).
81 FunctionDef TwoDeviceInputOutput();
82 
83 // Function taking a list of Tensors as input.
84 FunctionDef FuncWithListInput();
85 
86 // Function returning a list of Tensors as output.
87 FunctionDef FuncWithListOutput();
88 
89 // x: T -> x + x.
90 FunctionDef XAddX();
91 
92 // x: T, y: T -> x + y.
93 FunctionDef XAddY();
94 
95 // x: T -> x * 2, where x is int32.
96 FunctionDef XTimesTwoInt32();
97 
98 // x: T -> (x * 2) * 2.
99 FunctionDef XTimesFour();
100 
101 // x: T -> ((x * 2) * 2) * 2.
102 FunctionDef XTimes16();
103 
104 // w: T, x: T, b: T -> MatMul(w, x) + b
105 FunctionDef WXPlusB();
106 
107 // x: T -> x: T, T is a type which we automatically converts to a bool.
108 FunctionDef NonZero();
109 
110 // x: T -> bool.
111 FunctionDef IsZero();
112 
113 // x: T -> int64
114 FunctionDef RandomUniform();
115 
116 // x: T, y:T  -> y: T, x: T
117 FunctionDef Swap();
118 
119 // x: T, y: T -> y: T, x: T, the body has no nodes.
120 FunctionDef EmptyBodySwap();
121 
122 // x: float, y: resource -> y: resource, 2*x: float.
123 FunctionDef ResourceOutput();
124 
125 // x: resource -> x: resource
126 FunctionDef ResourceIdentity();
127 
128 // x: resource -> y: float.
129 FunctionDef ReadResourceVariable();
130 
131 // Contains simple control flow returning the input via an Enter op.
132 FunctionDef ControlFlow();
133 
134 // Contains malformed control flow which can't be run by the executor.
135 FunctionDef InvalidControlFlow();
136 
137 // x: T -> x <= N.
138 FunctionDef LessThanOrEqualToN(int64_t N);
139 
140 // x: T, y: T -> x + 1, x * y
141 FunctionDef XPlusOneXTimesY();
142 
143 // x: T, y: T -> x <= N
144 FunctionDef XYXLessThanOrEqualToN(int64_t N);
145 
146 // x: T -> bool
147 FunctionDef RandomUniformLess();
148 
149 // start: int64, stop: int64, step: int64 -> y: RangeDatasetOp::Dataset
150 FunctionDef MakeRangeDataset();
151 
152 // input_dataset: variant, batch_size: int64, drop_remainder: bool
153 // -> y: BatchDatasetV2::Dataset
154 FunctionDef MakeBatchDataset();
155 
156 // input_dataset: variant, other_arguments: Targuments, f: func,
157 // Targuments: list(type), output_types: list(type), output_shapes: list(shape),
158 // use_inter_op_parallelism: bool, preserve_cardinality: bool
159 // -> y: MapDatasetOp::Dataset
160 FunctionDef MakeMapDataset(bool has_other_args);
161 
162 // input_dataset: variant, count: int64 -> y: TakeDataset::Dataset
163 FunctionDef MakeTakeDataset();
164 
165 // x: T -> y: TensorSliceDatasetOp::Dataset
166 FunctionDef MakeTensorSliceDataset();
167 
168 // x: T -> y: T, idx: out_idx
169 FunctionDef Unique();
170 
171 void FunctionTestSchedClosure(std::function<void()> fn);
172 
173 }  // end namespace function
174 }  // end namespace test
175 }  // end namespace tensorflow
176 
177 #endif  // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_
178