/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/shape_inference.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

// Converts a shape inference handle to a PartialTensorShape.
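// A dimension the InferenceContext reports as unknown comes back from
// Value() as -1 (InferenceContext::kUnknownDim), which MakePartialShape()
// records as an unknown dimension of the resulting partial shape.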
Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
                                const shape_inference::ShapeHandle& handle,
                                PartialTensorShape* shape) {
  // The default is already unknown.
  if (!context->RankKnown(handle)) return OkStatus();

  std::vector<int64_t> dims(context->Rank(handle));
  for (int32_t i = 0, end = dims.size(); i < end; ++i) {
    dims[i] = context->Value(context->Dim(handle, i));
  }
  return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
}

Status PropagateShapes(Graph* graph,
                       const std::map<int, InferredShape>& arg_shapes,
                       const std::vector<BackEdgeHelper::BackEdge>& back_edges,
                       ShapeRefiner* shape_refiner) {
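  // Each removed back edge pairs a NextIteration node with the Merge node it
  // fed; remember the mapping so that loop-invariant resource inputs can be
  // recognized below.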
  std::map<const Node*, const Node*> merge_to_next_iteration;
  for (const auto& e : back_edges) {
    if (e.src->IsNextIteration() && e.dst->IsMerge()) {
      merge_to_next_iteration[e.dst] = e.src;
    }
  }

  // Visits the nodes in topological order (reverse post-order), inferring
  // shapes.
  // TODO(phawkins): handle cyclic graphs.
  std::vector<Node*> order;
  GetReversePostOrder(*graph, &order);

  for (Node* n : order) {
    // Ignore the status returned by the shape_refiner. We want the
    // best-effort shapes, even if no shape function is registered for a node.
    Status status = shape_refiner->AddNode(n);
    if (!status.ok()) {
      VLOG(1) << "Shape inference failed for node " << n->name() << ": "
              << status;
    } else {
      shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
      for (int i = 0; i < n->num_outputs(); i++) {
        shape_inference::ShapeHandle handle = context->output(i);
        VLOG(4) << "Output " << i << " for node " << n->name() << ": "
                << context->DebugString(handle);
      }
    }

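    // _Arg nodes carry the shapes supplied by the caller: seed the refiner
    // with the argument's shape and, for resource arguments, the handle's
    // shape and type.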
    if (n->type_string() == "_Arg") {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      auto it = arg_shapes.find(index);
      if (it != arg_shapes.end()) {
        const InferredShape& arg_shape = it->second;
        shape_inference::InferenceContext* context =
            shape_refiner->GetContext(n);

        if (arg_shape.handle_type != DT_INVALID) {
          shape_inference::ShapeHandle handle;
          TF_RETURN_IF_ERROR(context->MakeShapeFromPartialTensorShape(
              arg_shape.handle_shape, &handle));

          // Sets the shape and type of the variable's value.
          context->set_output_handle_shapes_and_types(
              0, std::vector<shape_inference::ShapeAndType>{
                     {handle, arg_shape.handle_type}});
        }

        shape_inference::ShapeHandle handle;
        TF_RETURN_IF_ERROR(
            context->MakeShapeFromPartialTensorShape(arg_shape.shape, &handle));
        TF_RETURN_IF_ERROR(shape_refiner->SetShape(n, 0, handle));
      }
    }

    // Sometimes we have VariableShape nodes inside a while loop (after Enter
    // nodes). They won't be constant-folded because TensorFlow constant
    // folding does not handle Enter nodes (and thus does not handle any nodes
    // after Enter nodes). We try to replace such VariableShape nodes with
    // Const nodes here.
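    // For example, if the variable's value is known to have shape [2, 3],
    // the VariableShape node is replaced by a Const node holding {2, 3}.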
    if (n->type_string() == "VariableShape") {
      shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
      auto handle_shapes_and_types = context->input_handle_shapes_and_types(0);
      if (handle_shapes_and_types && !handle_shapes_and_types->empty()) {
        shape_inference::ShapeHandle handle =
            handle_shapes_and_types->at(0).shape;
        TensorShapeProto shape_proto;
        context->ShapeHandleToProto(handle, &shape_proto);
        if (!shape_proto.unknown_rank()) {
          NodeDef const_def;
          const_def.set_op("Const");
          Node* var_node;
          TF_RETURN_IF_ERROR(n->input_node(0, &var_node));
          const_def.set_name(
              graph->NewName(absl::StrCat("var_shape_", var_node->name())));
          DataType dtype = n->output_type(0);
          AddNodeAttr("dtype", dtype, &const_def);
          TensorProto value;
          value.set_dtype(dtype);
          value.mutable_tensor_shape()->add_dim()->set_size(
              shape_proto.dim_size());
          for (const auto& dim : shape_proto.dim()) {
            if (dtype == DT_INT32) {
              value.add_int_val(dim.size());
            } else {
              value.add_int64_val(dim.size());
            }
          }
          AddNodeAttr("value", value, &const_def);
          for (const auto& attr : n->attrs()) {
            if (*attr.first.begin() == '_') {
              AddNodeAttr(attr.first, attr.second, &const_def);
            }
          }

          TF_ASSIGN_OR_RETURN(Node * const_node, graph->AddNode(const_def));
          graph->AddControlEdge(var_node, const_node);
          std::vector<const Edge*> out_edges(n->out_edges().begin(),
                                             n->out_edges().end());
          for (const Edge* e : out_edges) {
            if (e->IsControlEdge()) {
              graph->AddControlEdge(const_node, e->dst());
              graph->RemoveEdge(e);
            } else {
              Node* dst = e->dst();
              int dst_input = e->dst_input();
              graph->RemoveEdge(e);
              graph->AddEdge(const_node, 0, dst, dst_input);
            }
          }
        }
      }
    }

    // A Merge node in a loop creates a cycle, so we remove the
    // NextIteration->Merge edges before running shape inference. Removing
    // those edges, however, also prevents us from inferring the Merge node's
    // output shape, since that requires the shapes of all its inputs.
    // For the Merge node of a loop-invariant resource input, we set the
    // output resource shape to the Enter node's resource shape.
    // TODO(b/129367850): clean this up.
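    // A loop-invariant resource input follows the pattern
    //   Merge -> Switch -> (Identity)* -> NextIteration -> Merge,
    // i.e. the value fed back through NextIteration is the Merge node's own
    // output, unchanged by the loop body.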
    if (n->IsMerge() && n->output_type(0) == DT_RESOURCE) {
      // Check if this is a loop-invariant input's Merge node. We do so by
      // checking whether the corresponding NextIteration node is fed,
      // possibly through a chain of Identity nodes, by a Switch node whose
      // data input is this Merge node.
      auto iter = merge_to_next_iteration.find(n);
      if (iter != merge_to_next_iteration.end()) {
        const Node *next_iter = iter->second, *node = next_iter;
        do {
          TF_RETURN_IF_ERROR(node->input_node(0, &node));
        } while (node->IsIdentity());
        const Node* switch_input;
        bool is_loop_invariant = node->IsSwitch() &&
                                 node->input_node(0, &switch_input).ok() &&
                                 switch_input == n;
        if (is_loop_invariant) {
          shape_inference::InferenceContext* context =
              shape_refiner->GetContext(n);
          for (int i = 0; i < n->num_inputs(); i++) {
            const Node* input_node;
            if (n->input_node(i, &input_node).ok()) {
              auto shapes_and_types = context->input_handle_shapes_and_types(i);
              if (shapes_and_types) {
                context->set_output_handle_shapes_and_types(0,
                                                            *shapes_and_types);
              }
              break;
            }
          }
        }
      }
    }
  }
  return OkStatus();
}

// Stores the shapes of each node's output tensors in shape_info, keyed by
// node name.
Status StoreOutputShapes(const Graph& graph, const ShapeRefiner& shape_refiner,
                         GraphShapeInfo* shape_info) {
  for (const Node* node : graph.nodes()) {
    shape_inference::InferenceContext* context = shape_refiner.GetContext(node);
    if (!context) continue;

    auto& outputs = (*shape_info)[node->name()];
    outputs.resize(context->num_outputs());
    for (int i = 0; i < context->num_outputs(); ++i) {
      auto& output = outputs[i];
      TF_RETURN_IF_ERROR(
          ShapeHandleToTensorShape(context, context->output(i), &output.shape));

      const auto* handle_shapes_and_types =
          context->output_handle_shapes_and_types(i);
      if (handle_shapes_and_types != nullptr) {
        if (handle_shapes_and_types->size() == 1) {
          TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(
              context, (*handle_shapes_and_types)[0].shape,
              &output.handle_shape));
          output.handle_type = (*handle_shapes_and_types)[0].dtype;
        } else {
          // Otherwise, the handle may be a resource such as a Queue, which
          // can have multiple shapes and types represented by a single
          // handle.
        }
      }
      VLOG(4) << node->name() << " output " << i << " shape "
              << output.shape.DebugString() << " handle_type "
              << DataTypeString(output.handle_type) << " handle_shape "
              << output.handle_shape.DebugString();
    }
  }
  return OkStatus();
}

}  // namespace

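// Example usage (a minimal sketch; assumes `graph` is a fully constructed
// Graph whose _Arg node with index 0 has a known shape):
//
//   GraphShapeInfo shape_info;
//   std::map<int, InferredShape> arg_shapes;
//   arg_shapes[0].shape = PartialTensorShape({2, 3});
//   TF_RETURN_IF_ERROR(
//       InferShapes(graph, arg_shapes, /*fnlib_def=*/nullptr, &shape_info));
//   // shape_info now maps each node name to its per-output InferredShapes.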
Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
                   const tensorflow::FunctionLibraryDefinition* fnlib_def,
                   GraphShapeInfo* shape_info) {
  ShapeRefiner shape_refiner(graph->versions(), graph->op_registry());
  shape_refiner.set_require_shape_inference_fns(false);
  // TODO(dlibenzi): Verify whether it is worth trying to infer shapes within
  // functions. Some functions can be called at multiple locations with
  // different shapes, which will trigger a shape inference based on the
  // arguments passed at the first call.
  // shape_refiner.set_function_library_for_shape_inference(fnlib_def);

  // ShapeRefiner requires that all inputs of a node are present when
  // ShapeRefiner::AddNode is called. To get at least some shape information in
  // loops, we temporarily remove loop backedges and add them back again after
  // the shape inference is complete.
  BackEdgeHelper back_edge;
  TF_RETURN_IF_ERROR(back_edge.Remove(graph));
  TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes,
                                     back_edge.RemovedEdges(), &shape_refiner));
  TF_RETURN_IF_ERROR(back_edge.Replace());

  // Currently information does not flow "backward" from consumers to producers
  // in the shape inference, but we consume the shapes in a second pass in case
  // backward information flow is added in the future.
  return StoreOutputShapes(*graph, shape_refiner, shape_info);
}

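// Merges the two inferred shapes via PartialTensorShape::MergeWith, e.g.
// merging [?, 10] with [2, ?] yields [2, 10]; incompatible known dimensions
// or mismatched resource handle types yield an InvalidArgument error.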
StatusOr<InferredShape> MergeInferredShapes(const InferredShape& a,
                                            const InferredShape& b) {
  InferredShape result;
  TF_RETURN_IF_ERROR(a.shape.MergeWith(b.shape, &result.shape));

  if (a.handle_type == DT_INVALID) {
    result.handle_type = b.handle_type;
  } else if (b.handle_type == DT_INVALID) {
    result.handle_type = a.handle_type;
  } else if (a.handle_type == b.handle_type) {
    result.handle_type = a.handle_type;
  } else {
    return errors::InvalidArgument(
        "Mismatched resource types: ", DataTypeString(a.handle_type), " vs. ",
        DataTypeString(b.handle_type));
  }
  TF_RETURN_IF_ERROR(
      a.handle_shape.MergeWith(b.handle_shape, &result.handle_shape));
  return result;
}

}  // namespace tensorflow