1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_ 17 #define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_ 18 19 #include <string> 20 21 #include "tensorflow/core/framework/attr_value_util.h" 22 #include "tensorflow/core/framework/function.h" 23 #include "tensorflow/core/framework/function.pb.h" 24 #include "tensorflow/core/framework/graph.pb.h" 25 #include "tensorflow/core/framework/node_def.pb.h" 26 #include "tensorflow/core/lib/gtl/array_slice.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace tensorflow { 30 namespace test { 31 namespace function { 32 33 // A helper class to make AttrSlice from initializer lists 34 class Attrs { 35 public: Attrs(const std::initializer_list<std::pair<string,FunctionDefHelper::AttrValueWrapper>> & attrs)36 Attrs(const std::initializer_list< // NOLINT(runtime/explicit) 37 std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) { 38 for (const auto& aval : attrs) { 39 map_.insert({aval.first, aval.second.proto}); 40 } 41 } 42 Attrs(const std::vector<std::pair<string,FunctionDefHelper::AttrValueWrapper>> & attrs)43 Attrs( 44 const std::vector<std::pair<string, FunctionDefHelper::AttrValueWrapper>>& 45 attrs) { 46 for (const auto& aval : attrs) { 47 map_.insert({aval.first, aval.second.proto}); 48 } 49 } 50 AttrSlice()51 operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) 52 53 private: 54 AttrValueMap map_; 55 }; 56 57 // Helper to construct a NodeDef. 58 NodeDef NDef( 59 StringPiece name, StringPiece op, gtl::ArraySlice<string> inputs, 60 gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>> 61 attrs = {}, 62 const string& device = ""); 63 64 // Helper to construct a GraphDef proto. 65 GraphDef GDef(gtl::ArraySlice<NodeDef> nodes, 66 gtl::ArraySlice<FunctionDef> funcs = {}); 67 68 // For testing convenience, we provide a few simple functions that can 69 // be easily executed and tested. 70 71 // x: T -> x * 2. 72 FunctionDef XTimesTwo(); 73 74 // x: T -> cpu(x * 2) + cpu(x * 3). 75 FunctionDef TwoDeviceTimesFive(); 76 77 // x: T -> cpu(x * 2), gpu(x * 3). 78 FunctionDef TwoDeviceMult(); 79 80 // cpu(x): T, gpu(y): T -> cpu(x * 2), gpu(y * 3). 81 FunctionDef TwoDeviceInputOutput(); 82 83 // Function taking a list of Tensors as input. 84 FunctionDef FuncWithListInput(); 85 86 // Function returning a list of Tensors as output. 87 FunctionDef FuncWithListOutput(); 88 89 // x: T -> x + x. 90 FunctionDef XAddX(); 91 92 // x: T, y: T -> x + y. 93 FunctionDef XAddY(); 94 95 // x: T -> x * 2, where x is int32. 96 FunctionDef XTimesTwoInt32(); 97 98 // x: T -> (x * 2) * 2. 99 FunctionDef XTimesFour(); 100 101 // x: T -> ((x * 2) * 2) * 2. 102 FunctionDef XTimes16(); 103 104 // w: T, x: T, b: T -> MatMul(w, x) + b 105 FunctionDef WXPlusB(); 106 107 // x: T -> x: T, T is a type which we automatically converts to a bool. 108 FunctionDef NonZero(); 109 110 // x: T -> bool. 111 FunctionDef IsZero(); 112 113 // x: T -> int64 114 FunctionDef RandomUniform(); 115 116 // x: T, y:T -> y: T, x: T 117 FunctionDef Swap(); 118 119 // x: T, y: T -> y: T, x: T, the body has no nodes. 120 FunctionDef EmptyBodySwap(); 121 122 // x: float, y: resource -> y: resource, 2*x: float. 123 FunctionDef ResourceOutput(); 124 125 // x: resource -> x: resource 126 FunctionDef ResourceIdentity(); 127 128 // x: resource -> y: float. 129 FunctionDef ReadResourceVariable(); 130 131 // Contains simple control flow returning the input via an Enter op. 132 FunctionDef ControlFlow(); 133 134 // Contains malformed control flow which can't be run by the executor. 135 FunctionDef InvalidControlFlow(); 136 137 // x: T -> x <= N. 138 FunctionDef LessThanOrEqualToN(int64_t N); 139 140 // x: T, y: T -> x + 1, x * y 141 FunctionDef XPlusOneXTimesY(); 142 143 // x: T, y: T -> x <= N 144 FunctionDef XYXLessThanOrEqualToN(int64_t N); 145 146 // x: T -> bool 147 FunctionDef RandomUniformLess(); 148 149 // start: int64, stop: int64, step: int64 -> y: RangeDatasetOp::Dataset 150 FunctionDef MakeRangeDataset(); 151 152 // input_dataset: variant, batch_size: int64, drop_remainder: bool 153 // -> y: BatchDatasetV2::Dataset 154 FunctionDef MakeBatchDataset(); 155 156 // input_dataset: variant, other_arguments: Targuments, f: func, 157 // Targuments: list(type), output_types: list(type), output_shapes: list(shape), 158 // use_inter_op_parallelism: bool, preserve_cardinality: bool 159 // -> y: MapDatasetOp::Dataset 160 FunctionDef MakeMapDataset(bool has_other_args); 161 162 // input_dataset: variant, count: int64 -> y: TakeDataset::Dataset 163 FunctionDef MakeTakeDataset(); 164 165 // x: T -> y: TensorSliceDatasetOp::Dataset 166 FunctionDef MakeTensorSliceDataset(); 167 168 // x: T -> y: T, idx: out_idx 169 FunctionDef Unique(); 170 171 void FunctionTestSchedClosure(std::function<void()> fn); 172 173 } // end namespace function 174 } // end namespace test 175 } // end namespace tensorflow 176 177 #endif // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_ 178