/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/shape_inference.h"

#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

// Converts a shape inference handle to a PartialTensorShape.
Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
                                const shape_inference::ShapeHandle& handle,
                                PartialTensorShape* shape) {
  // An unknown rank maps to the default-constructed PartialTensorShape, which
  // is already fully unknown.
  if (!context->RankKnown(handle)) return OkStatus();

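  // Unknown dimensions are reported as InferenceContext::kUnknownDim (-1),
  // which MakePartialShape in turn treats as an unknown dimension.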
  std::vector<int64_t> dims(context->Rank(handle));
  for (int32_t i = 0, end = dims.size(); i < end; ++i) {
    dims[i] = context->Value(context->Dim(handle, i));
  }
  return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
}

Status PropagateShapes(Graph* graph,
                       const std::map<int, InferredShape>& arg_shapes,
                       const std::vector<BackEdgeHelper::BackEdge>& back_edges,
                       ShapeRefiner* shape_refiner) {
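  // Records which NextIteration node feeds each Merge node. The back edges
  // have already been removed from the graph, so this map is the only way to
  // recover the pairing when handling loop-invariant resources below.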
  std::map<const Node*, const Node*> merge_to_next_iteration;
  for (const auto& e : back_edges) {
    if (e.src->IsNextIteration() && e.dst->IsMerge()) {
      merge_to_next_iteration[e.dst] = e.src;
    }
  }

  // Visits the nodes in topological order (reverse post-order), inferring
  // shapes.
  // TODO(phawkins): handle cyclic graphs.
  std::vector<Node*> order;
  GetReversePostOrder(*graph, &order);

  for (Node* n : order) {
    // Ignore the status returned by the shape_refiner. We want best-effort
    // shapes, even if no shape function is registered for a node.
    Status status = shape_refiner->AddNode(n);
    if (!status.ok()) {
      VLOG(1) << "Shape inference failed for node " << n->name() << ": "
              << status;
    } else {
      shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
      for (int i = 0; i < n->num_outputs(); i++) {
        shape_inference::ShapeHandle handle = context->output(i);
        VLOG(4) << "Output " << i << " for node " << n->name() << ": "
                << context->DebugString(handle);
      }
    }

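    // Seed _Arg nodes with the caller-provided argument shapes (and, for
    // resources, the shape and type of the value the handle refers to).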
    if (n->type_string() == "_Arg") {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      auto it = arg_shapes.find(index);
      if (it != arg_shapes.end()) {
        const InferredShape& arg_shape = it->second;
        shape_inference::InferenceContext* context =
            shape_refiner->GetContext(n);

        if (arg_shape.handle_type != DT_INVALID) {
          shape_inference::ShapeHandle handle;
          TF_RETURN_IF_ERROR(context->MakeShapeFromPartialTensorShape(
              arg_shape.handle_shape, &handle));

          // Sets the shape and type of the variable's value.
          context->set_output_handle_shapes_and_types(
              0, std::vector<shape_inference::ShapeAndType>{
                     {handle, arg_shape.handle_type}});
        }

        shape_inference::ShapeHandle handle;
        TF_RETURN_IF_ERROR(
            context->MakeShapeFromPartialTensorShape(arg_shape.shape, &handle));
        TF_RETURN_IF_ERROR(shape_refiner->SetShape(n, 0, handle));
      }
    }

    // Sometimes we have VariableShape nodes in a while loop (after Enter
    // nodes). They won't be constant-folded because TensorFlow constant
    // folding does not handle Enter nodes (and thus does not handle any nodes
    // after Enter nodes). We try to replace such VariableShape nodes with
    // Const nodes here.
    if (n->type_string() == "VariableShape") {
      shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
      auto handle_shapes_and_types = context->input_handle_shapes_and_types(0);
      if (handle_shapes_and_types && !handle_shapes_and_types->empty()) {
        shape_inference::ShapeHandle handle =
            handle_shapes_and_types->at(0).shape;
        TensorShapeProto shape_proto;
        context->ShapeHandleToProto(handle, &shape_proto);
        if (!shape_proto.unknown_rank()) {
          NodeDef const_def;
          const_def.set_op("Const");
          Node* var_node;
          TF_RETURN_IF_ERROR(n->input_node(0, &var_node));
          const_def.set_name(
              graph->NewName(absl::StrCat("var_shape_", var_node->name())));
          DataType dtype = n->output_type(0);
          AddNodeAttr("dtype", dtype, &const_def);
          TensorProto value;
          value.set_dtype(dtype);
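          // VariableShape produces a 1-D tensor with one entry per dimension
          // of the variable's shape.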
          value.mutable_tensor_shape()->add_dim()->set_size(
              shape_proto.dim_size());
          for (const auto& dim : shape_proto.dim()) {
            if (dtype == DT_INT32) {
              value.add_int_val(dim.size());
            } else {
              value.add_int64_val(dim.size());
            }
          }
          AddNodeAttr("value", value, &const_def);
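          // Carry over internal attributes (names beginning with '_'), e.g.
          // placement constraints, to the replacement constant.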
          for (auto const& attr : n->attrs()) {
            if (*attr.first.begin() == '_') {
              AddNodeAttr(attr.first, attr.second, &const_def);
            }
          }

          TF_ASSIGN_OR_RETURN(Node * const_node, graph->AddNode(const_def));
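          // Keep the constant ordered after the variable it was derived from.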
          graph->AddControlEdge(var_node, const_node);
          std::vector<const Edge*> out_edges(n->out_edges().begin(),
                                             n->out_edges().end());
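          // Redirect every consumer of the VariableShape node, both data and
          // control, to the new constant.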
          for (const Edge* e : out_edges) {
            if (e->IsControlEdge()) {
              graph->AddControlEdge(const_node, e->dst());
              graph->RemoveEdge(e);
            } else {
              Node* dst = e->dst();
              int dst_input = e->dst_input();
              graph->RemoveEdge(e);
              graph->AddEdge(const_node, 0, dst, dst_input);
            }
          }
        }
      }
    }

    // A Merge node closes a loop, so we remove the NextIteration->Merge edges
    // before performing shape inference. But removing those edges also
    // prevents us from inferring the output shape of the Merge node (we need
    // shapes for all of its inputs).
    // For a loop-invariant resource input's Merge node, we set the output
    // resource shape to the Enter node's resource shape.
    // TODO(b/129367850): clean this up.
    if (n->IsMerge() && n->output_type(0) == DT_RESOURCE) {
      // Check if this is a loop-invariant input's Merge node. We do it by
      // checking if the corresponding NextIteration node comes directly from a
      // Switch node.
      auto iter = merge_to_next_iteration.find(n);
      if (iter != merge_to_next_iteration.end()) {
        const Node *next_iter = iter->second, *node = next_iter;
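        // Walk upstream from the NextIteration node, skipping over Identity
        // nodes, to find the node that actually produces its input.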
        do {
          TF_RETURN_IF_ERROR(node->input_node(0, &node));
        } while (node->IsIdentity());
        const Node* switch_input;
        bool is_loop_invariant = node->IsSwitch() &&
                                 node->input_node(0, &switch_input).ok() &&
                                 switch_input == n;
        if (is_loop_invariant) {
          shape_inference::InferenceContext* context =
              shape_refiner->GetContext(n);
          for (int i = 0; i < n->num_inputs(); i++) {
            const Node* input_node;
            if (n->input_node(i, &input_node).ok()) {
              auto shapes_and_types = context->input_handle_shapes_and_types(i);
              if (shapes_and_types) {
                context->set_output_handle_shapes_and_types(0,
                                                            *shapes_and_types);
              }
              break;
            }
          }
        }
      }
    }
  }
  return OkStatus();
}

// Stores the output shapes of every node in *shape_info, keyed by node name.
Status StoreOutputShapes(const Graph& graph, const ShapeRefiner& shape_refiner,
                         GraphShapeInfo* shape_info) {
  for (const Node* node : graph.nodes()) {
    shape_inference::InferenceContext* context = shape_refiner.GetContext(node);
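    // Nodes the refiner failed to add have no context; they are skipped and
    // stay absent from *shape_info.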
    if (!context) continue;

    auto& outputs = (*shape_info)[node->name()];
    outputs.resize(context->num_outputs());
    for (int i = 0; i < context->num_outputs(); ++i) {
      auto& output = outputs[i];
      TF_RETURN_IF_ERROR(
          ShapeHandleToTensorShape(context, context->output(i), &output.shape));

      const auto* handle_shapes_and_types =
          context->output_handle_shapes_and_types(i);
      if (handle_shapes_and_types != nullptr) {
        if (handle_shapes_and_types->size() == 1) {
          TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(
              context, (*handle_shapes_and_types)[0].shape,
              &output.handle_shape));
          output.handle_type = (*handle_shapes_and_types)[0].dtype;
        } else {
          // Otherwise the handle may refer to a resource such as a Queue,
          // which can have multiple shapes and types represented by a single
          // handle; leave the output's handle fields at their defaults.
        }
      }
      VLOG(4) << node->name() << " output " << i << " shape "
              << output.shape.DebugString() << " handle_type "
              << DataTypeString(output.handle_type) << " handle_shape "
              << output.handle_shape.DebugString();
    }
  }
  return OkStatus();
}

}  // namespace

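// Illustrative usage (a sketch, not code from this file): infer shapes for a
// graph whose argument 0 is a rank-2 tensor with an unknown leading (batch)
// dimension. Passing nullptr for fnlib_def is harmless in this revision
// because function shape inference is disabled (see the commented-out call
// below).
//
//   std::map<int, InferredShape> arg_shapes;
//   arg_shapes[0].shape = PartialTensorShape({-1, 128});  // -1 = unknown dim
//   GraphShapeInfo shape_info;
//   TF_RETURN_IF_ERROR(
//       InferShapes(graph, arg_shapes, /*fnlib_def=*/nullptr, &shape_info));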
Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
                   const tensorflow::FunctionLibraryDefinition* fnlib_def,
                   GraphShapeInfo* shape_info) {
  ShapeRefiner shape_refiner(graph->versions(), graph->op_registry());
  shape_refiner.set_require_shape_inference_fns(false);
  // TODO(dlibenzi): Verify whether it is worth trying to infer shapes within
  // functions. Some functions can be called at multiple call sites with
  // different shapes, which would trigger a shape inference based on the
  // arguments passed at the first call.
  // shape_refiner.set_function_library_for_shape_inference(fnlib_def);

  // ShapeRefiner requires that all inputs of a node are present when
  // ShapeRefiner::AddNode is called. To get at least some shape information in
  // loops, we temporarily remove loop backedges and add them back again after
  // the shape inference is complete.
  BackEdgeHelper back_edge;
  TF_RETURN_IF_ERROR(back_edge.Remove(graph));
  TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes,
                                     back_edge.RemovedEdges(), &shape_refiner));
  TF_RETURN_IF_ERROR(back_edge.Replace());

  // Currently information does not flow "backward" from consumers to producers
  // in the shape inference, but we consume the shapes in a second pass in case
  // backward information flow is added in the future.
  return StoreOutputShapes(*graph, shape_refiner, shape_info);
}

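// Merges two InferredShapes, refining unknown dimensions where possible. The
// merge fails if the shapes are incompatible, or if both carry a resource
// handle type and the types disagree.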
StatusOr<InferredShape> MergeInferredShapes(const InferredShape& a,
                                            const InferredShape& b) {
  InferredShape result;
  TF_RETURN_IF_ERROR(a.shape.MergeWith(b.shape, &result.shape));

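  // Prefer whichever handle type is known; it is an error for both to be
  // known but different.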
  if (a.handle_type == DT_INVALID) {
    result.handle_type = b.handle_type;
  } else if (b.handle_type == DT_INVALID) {
    result.handle_type = a.handle_type;
  } else if (a.handle_type == b.handle_type) {
    result.handle_type = a.handle_type;
  } else {
    return errors::InvalidArgument(
        "Mismatched resource types: ", DataTypeString(a.handle_type), " vs. ",
        DataTypeString(b.handle_type));
  }
  TF_RETURN_IF_ERROR(
      a.handle_shape.MergeWith(b.handle_shape, &result.handle_shape));
  return result;
}

}  // namespace tensorflow