xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/function_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/attr_value.pb.h"
17 #include "tensorflow/core/framework/common_shape_fns.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 
23 namespace tensorflow {
24 
25 REGISTER_SYSTEM_OP("_Arg")
26     .Output("output: T")
27     .Attr("T: type")
28     .Attr("index: int >= 0")
29     .SetIsStateful()
__anon748b2ccf0102(shape_inference::InferenceContext* context) 30     .SetShapeFn([](shape_inference::InferenceContext* context) {
31       const AttrValue* dtype_attr = context->GetAttr("T");
32       if (!dtype_attr) {
33         return errors::InvalidArgument(
34             "_Arg node does not have attribute \"T\"");
35       }
36 
37       const AttrValue* shape_attr = context->GetAttr("_output_shapes");
38       if (shape_attr && shape_attr->has_list()) {
39         if (shape_attr->list().shape().empty()) {
40           return errors::InvalidArgument(
41               "Invalid \"_output_shapes\" attribute value for _Arg node: ",
42               shape_attr->DebugString());
43         }
44         const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
45         shape_inference::ShapeHandle shape_handle;
46         TF_RETURN_IF_ERROR(
47             context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
48         context->set_output(0, shape_handle);
49       } else {
50         context->set_output(0, context->UnknownShape());
51       }
52 
53       if (dtype_attr->type() != DT_RESOURCE) {
54         return OkStatus();
55       }
56 
57       // If the argument is for a resource type, then also try to infer the
58       // type of the tensor store in the resource type.
59       dtype_attr = context->GetAttr("_handle_dtypes");
60       shape_attr = context->GetAttr("_handle_shapes");
61       // If either the shape or type attribute is not set then simply return
62       // with unknown output set above.
63       if (!dtype_attr || !shape_attr) {
64         return OkStatus();
65       }
66 
67       if (dtype_attr->list().type().empty()) {
68         return errors::InvalidArgument(
69             "Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
70             dtype_attr->DebugString());
71       }
72       if (shape_attr->list().shape().empty()) {
73         return errors::InvalidArgument(
74             "Invalid \"_handle_shapes\" attribute value for _Arg node: ",
75             shape_attr->DebugString());
76       }
77       DataType dtype = dtype_attr->list().type(0);
78       const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
79       shape_inference::ShapeHandle shape_handle;
80       TF_RETURN_IF_ERROR(
81           context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
82       context->set_output_handle_shapes_and_types(
83           0, std::vector<shape_inference::ShapeAndType>{{shape_handle, dtype}});
84       return OkStatus();
85     })
86     .Doc(R"doc(
87 A graph node which represents an argument to a function.
88 
89 output: The argument.
90 index: This argument is the index-th argument of the function.
91 
92 Attributes for shape inference:
93 1. _output_shapes: this attribute should contain a list of TensorShapeProto
94    describing the shape(s) of the tensor(s) this _Arg node will produce. If set,
95    _Arg node's shape inference function will use it as the node's output shapes.
96 2. _handle_dtypes and _handle_shapes: these attributes can be set on an _Arg
97    node producing resource output(s). If set, value of _handle_dtypes should
98    contain the dtype(s) of the resource(s) and value of _handle_shapes should
99    contain the shape(s) of the resource(s). If both attributes are set, _Arg
100    node's shape inference function will use their values as the node's output
101    handle's type(s) and shape(s).
102 )doc");
103 
104 REGISTER_SYSTEM_OP("_DeviceArg")
105     .Output("output: T")
106     .Attr("T: type")
107     .Attr("index: int >= 0")
108     .SetIsStateful()
__anon748b2ccf0202(shape_inference::InferenceContext* context) 109     .SetShapeFn([](shape_inference::InferenceContext* context) {
110       context->set_output(0, context->UnknownShape());
111       return OkStatus();
112     })
113     .Doc(R"doc(
114 A graph node which represents an argument to a function.
115 
116 output: The argument.
117 index: This argument is the index-th argument of the function.
118 )doc");
119 
120 REGISTER_SYSTEM_OP("_Retval")
121     .Input("input: T")
122     .Attr("T: type")
123     .Attr("index: int >= 0")
124     .SetIsStateful()
__anon748b2ccf0302(shape_inference::InferenceContext* context) 125     .SetShapeFn([](shape_inference::InferenceContext* context) {
126       return OkStatus();
127     })
128     .Doc(R"doc(
129 A graph node which represents a return value of a function.
130 
131 input: The return value.
132 index: This return value is the index-th return value of the function.
133 )doc");
134 
135 REGISTER_SYSTEM_OP("_DeviceRetval")
136     .Input("input: T")
137     .Attr("T: type")
138     .Attr("index: int >= 0")
139     .SetIsStateful()
__anon748b2ccf0402(shape_inference::InferenceContext* context) 140     .SetShapeFn([](shape_inference::InferenceContext* context) {
141       return OkStatus();
142     })
143     .Doc(R"doc(
144 A graph node which represents a return value of a function.
145 
146 input: The return value.
147 index: This return value is the index-th return value of the function.
148 )doc");
149 
150 REGISTER_SYSTEM_OP("_ListToArray")
151     .Input("input: Tin")
152     .Output("output: N * T")
153     .Attr("Tin: list(type)")
154     .Attr("T: type")
155     .Attr("N: int >= 1")
156     .SetShapeFn(shape_inference::UnknownShape)
157     .Doc(R"doc(
158 Converts a list of tensors to an array of tensors.
159 )doc");
160 
161 REGISTER_SYSTEM_OP("_ArrayToList")
162     .Input("input: N * T")
163     .Output("output: out_types")
164     .Attr("T: type")
165     .Attr("N: int >= 1")
166     .Attr("out_types: list(type)")
167     .SetShapeFn(shape_inference::UnknownShape)
168     .Doc(R"doc(
169 Converts an array of tensors to a list of tensors.
170 )doc");
171 
172 }  // namespace tensorflow
173