1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/attr_value.pb.h" 17 #include "tensorflow/core/framework/common_shape_fns.h" 18 #include "tensorflow/core/framework/op.h" 19 #include "tensorflow/core/framework/shape_inference.h" 20 #include "tensorflow/core/framework/tensor_shape.pb.h" 21 #include "tensorflow/core/lib/core/errors.h" 22 23 namespace tensorflow { 24 25 REGISTER_SYSTEM_OP("_Arg") 26 .Output("output: T") 27 .Attr("T: type") 28 .Attr("index: int >= 0") 29 .SetIsStateful() __anon748b2ccf0102(shape_inference::InferenceContext* context) 30 .SetShapeFn([](shape_inference::InferenceContext* context) { 31 const AttrValue* dtype_attr = context->GetAttr("T"); 32 if (!dtype_attr) { 33 return errors::InvalidArgument( 34 "_Arg node does not have attribute \"T\""); 35 } 36 37 const AttrValue* shape_attr = context->GetAttr("_output_shapes"); 38 if (shape_attr && shape_attr->has_list()) { 39 if (shape_attr->list().shape().empty()) { 40 return errors::InvalidArgument( 41 "Invalid \"_output_shapes\" attribute value for _Arg node: ", 42 shape_attr->DebugString()); 43 } 44 const TensorShapeProto& shape_proto = shape_attr->list().shape(0); 45 shape_inference::ShapeHandle shape_handle; 46 TF_RETURN_IF_ERROR( 47 context->MakeShapeFromShapeProto(shape_proto, &shape_handle)); 48 context->set_output(0, shape_handle); 49 } else { 50 context->set_output(0, context->UnknownShape()); 51 } 52 53 if (dtype_attr->type() != DT_RESOURCE) { 54 return OkStatus(); 55 } 56 57 // If the argument is for a resource type, then also try to infer the 58 // type of the tensor store in the resource type. 59 dtype_attr = context->GetAttr("_handle_dtypes"); 60 shape_attr = context->GetAttr("_handle_shapes"); 61 // If either the shape or type attribute is not set then simply return 62 // with unknown output set above. 63 if (!dtype_attr || !shape_attr) { 64 return OkStatus(); 65 } 66 67 if (dtype_attr->list().type().empty()) { 68 return errors::InvalidArgument( 69 "Invalid \"_handle_dtypes\" attribute value for _Arg node: ", 70 dtype_attr->DebugString()); 71 } 72 if (shape_attr->list().shape().empty()) { 73 return errors::InvalidArgument( 74 "Invalid \"_handle_shapes\" attribute value for _Arg node: ", 75 shape_attr->DebugString()); 76 } 77 DataType dtype = dtype_attr->list().type(0); 78 const TensorShapeProto& shape_proto = shape_attr->list().shape(0); 79 shape_inference::ShapeHandle shape_handle; 80 TF_RETURN_IF_ERROR( 81 context->MakeShapeFromShapeProto(shape_proto, &shape_handle)); 82 context->set_output_handle_shapes_and_types( 83 0, std::vector<shape_inference::ShapeAndType>{{shape_handle, dtype}}); 84 return OkStatus(); 85 }) 86 .Doc(R"doc( 87 A graph node which represents an argument to a function. 88 89 output: The argument. 90 index: This argument is the index-th argument of the function. 91 92 Attributes for shape inference: 93 1. _output_shapes: this attribute should contain a list of TensorShapeProto 94 describing the shape(s) of the tensor(s) this _Arg node will produce. If set, 95 _Arg node's shape inference function will use it as the node's output shapes. 96 2. _handle_dtypes and _handle_shapes: these attributes can be set on an _Arg 97 node producing resource output(s). If set, value of _handle_dtypes should 98 contain the dtype(s) of the resource(s) and value of _handle_shapes should 99 contain the shape(s) of the resource(s). If both attributes are set, _Arg 100 node's shape inference function will use their values as the node's output 101 handle's type(s) and shape(s). 102 )doc"); 103 104 REGISTER_SYSTEM_OP("_DeviceArg") 105 .Output("output: T") 106 .Attr("T: type") 107 .Attr("index: int >= 0") 108 .SetIsStateful() __anon748b2ccf0202(shape_inference::InferenceContext* context) 109 .SetShapeFn([](shape_inference::InferenceContext* context) { 110 context->set_output(0, context->UnknownShape()); 111 return OkStatus(); 112 }) 113 .Doc(R"doc( 114 A graph node which represents an argument to a function. 115 116 output: The argument. 117 index: This argument is the index-th argument of the function. 118 )doc"); 119 120 REGISTER_SYSTEM_OP("_Retval") 121 .Input("input: T") 122 .Attr("T: type") 123 .Attr("index: int >= 0") 124 .SetIsStateful() __anon748b2ccf0302(shape_inference::InferenceContext* context) 125 .SetShapeFn([](shape_inference::InferenceContext* context) { 126 return OkStatus(); 127 }) 128 .Doc(R"doc( 129 A graph node which represents a return value of a function. 130 131 input: The return value. 132 index: This return value is the index-th return value of the function. 133 )doc"); 134 135 REGISTER_SYSTEM_OP("_DeviceRetval") 136 .Input("input: T") 137 .Attr("T: type") 138 .Attr("index: int >= 0") 139 .SetIsStateful() __anon748b2ccf0402(shape_inference::InferenceContext* context) 140 .SetShapeFn([](shape_inference::InferenceContext* context) { 141 return OkStatus(); 142 }) 143 .Doc(R"doc( 144 A graph node which represents a return value of a function. 145 146 input: The return value. 147 index: This return value is the index-th return value of the function. 148 )doc"); 149 150 REGISTER_SYSTEM_OP("_ListToArray") 151 .Input("input: Tin") 152 .Output("output: N * T") 153 .Attr("Tin: list(type)") 154 .Attr("T: type") 155 .Attr("N: int >= 1") 156 .SetShapeFn(shape_inference::UnknownShape) 157 .Doc(R"doc( 158 Converts a list of tensors to an array of tensors. 159 )doc"); 160 161 REGISTER_SYSTEM_OP("_ArrayToList") 162 .Input("input: N * T") 163 .Output("output: out_types") 164 .Attr("T: type") 165 .Attr("N: int >= 1") 166 .Attr("out_types: list(type)") 167 .SetShapeFn(shape_inference::UnknownShape) 168 .Doc(R"doc( 169 Converts an array of tensors to a list of tensors. 170 )doc"); 171 172 } // namespace tensorflow 173