/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eval_const_tensor.h"

#include <deque>

#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {
namespace {

using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;

// Fills '*output' with the underlying constant value of 'node' if the node
// holds a constant value; '*success' reports whether it did.
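// Example (an illustrative sketch; 'node' here stands for some Node* in the
// graph being refined):
//   Tensor value;
//   bool is_const = false;
//   TF_RETURN_IF_ERROR(EvaluateConstantNode(*node, &value, &is_const));
//   if (is_const) {
//     // 'value' now holds the tensor stored in the Const node's "value" attr.
//   }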
Status EvaluateConstantNode(const Node& node, Tensor* output, bool* success) {
  *success = false;
  if (node.IsConstant()) {
    if (output->FromProto(node.def().attr().at("value").tensor())) {
      *success = true;
    }
  }
  return OkStatus();
}

// Returns the int value fed to 'node' along its 'input_idx'-th input edge if
// the source of that edge is a constant scalar tensor of type int32 or int64.
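// Example (illustrative; 'slice_node' is a hypothetical StridedSlice node
// whose input 1 is its 'begin' tensor):
//   int64 begin = 0;
//   bool ok = false;
//   TF_RETURN_IF_ERROR(
//       EvaluateConstantIntFromScalarEdge(*slice_node, 1, &begin, &ok));
//   // 'ok' is true and 'begin' holds the value when that input comes from a
//   // Const int32/int64 scalar.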
Status EvaluateConstantIntFromScalarEdge(const Node& node, int input_idx,
                                         int64* output, bool* success) {
  *success = false;
  Tensor scalar;
  const Edge* edge;
  TF_RETURN_IF_ERROR(node.input_edge(input_idx, &edge));
  TF_RETURN_IF_ERROR(EvaluateConstantNode(*edge->src(), &scalar, success));
  if (*success && scalar.NumElements() == 1) {
    if (scalar.dtype() == DT_INT32) {
      *output = scalar.scalar<int32>()();
    } else if (scalar.dtype() == DT_INT64) {
      *output = scalar.scalar<int64_t>()();
    } else {
      *success = false;
    }
  }
  return OkStatus();
}

// Tries to infer the tensor output based on the input dims of a
// Shape node.
// [allow_partial = false]
//   Can infer the Shape op's output tensor only when the
//   input shape to the Shape op is fully defined.
// [allow_partial = true]
//   Can infer the Shape op's output tensor as long as the rank of the input
//   shape to the Shape op is known. Uses kUnknownDim for unknown dims.
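// Example (illustrative): with allow_partial = true and an input whose
// inferred shape is (5, ?), the resulting output tensor is [5, -1], since
// InferenceContext::kUnknownDim is -1.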
Status TryToInferTensorOutputFromShapeNode(const Node& shape_node,
                                           InferenceContext* shape_c,
                                           Tensor* output, bool* success,
                                           bool allow_partial = false) {
  *success = false;
  if (shape_node.type_string() != "Shape") return OkStatus();
  if (shape_c == nullptr) return OkStatus();
  if (!shape_c->FullyDefined(shape_c->input(0)) && !allow_partial)
    return OkStatus();
  if (!shape_c->RankKnown(shape_c->input(0))) return OkStatus();

  int src_rank = shape_c->Rank(shape_c->input(0));
  Tensor t(shape_node.output_type(0), TensorShape({src_rank}));
  if (shape_node.output_type(0) == DT_INT32) {
    auto flat = t.flat<int>();
    for (int i = 0; i < src_rank; i++) {
      int64_t dimension;
      if (shape_c->ValueKnown(shape_c->Dim(shape_c->input(0), i))) {
        dimension = shape_c->Value(shape_c->Dim(shape_c->input(0), i));
        if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
          return errors::InvalidArgument(
              "Shape has output type int32, but dimension exceeds maximum "
              "int32 value");
        }
      } else {
        dimension = shape_c->kUnknownDim;
      }
      flat(i) = static_cast<int32>(dimension);
    }
  } else if (shape_node.output_type(0) == DT_INT64) {
    auto flat = t.flat<int64_t>();
    for (int i = 0; i < src_rank; i++) {
      if (shape_c->ValueKnown(shape_c->Dim(shape_c->input(0), i))) {
        flat(i) = shape_c->Value(shape_c->Dim(shape_c->input(0), i));
      } else {
        flat(i) = shape_c->kUnknownDim;
      }
    }
  } else {
    return errors::FailedPrecondition(
        "Shape has output type that is not int32 or int64");
  }
  *output = t;
  *success = true;
  return OkStatus();
}

// Tries to infer the tensor output of a StridedSlice node. This can be done
// when taking a slice of a fully defined Shape node or when taking a slice
// of a partial Shape node along a known dimension.
// Examples:
//  tf.shape(x)[0]; x.shape = (5, 10) - slicing a fully defined shape
//  tf.shape(x)[0]; x.shape = (5, ?) - slicing a partial shape along a known dim
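// Graph pattern handled here (an illustrative sketch): the node passed in is
// the StridedSlice, its input 0 comes from a Shape op, its begin/end/strides
// inputs are scalar constants, and its masks correspond to collection[i]
// (begin/end/ellipsis/new_axis masks all 0, shrink_axis_mask == 1):
//   x -> Shape -> StridedSlice(begin, end, strides) -> scalar dimension value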
Status TryToInferTensorOutputFromStridedSliceNode(const Node& node,
                                                  const ShapeRefiner& refiner,
                                                  Tensor* output,
                                                  bool* success) {
  *success = false;
  const Edge* edge;
  TF_RETURN_IF_ERROR(node.input_edge(0, &edge));
  const Node* shape_node = edge->src();
  const Node* stride_node = edge->dst();
  if (stride_node == nullptr || shape_node == nullptr) return OkStatus();
  if (stride_node->type_string() != "StridedSlice") return OkStatus();
  if (shape_node->type_string() != "Shape") return OkStatus();

  InferenceContext* shape_c = refiner.GetContext(shape_node);
  InferenceContext* stride_c = refiner.GetContext(stride_node);
  if (stride_c == nullptr || shape_c == nullptr) return OkStatus();

  // Only attempt to evaluate if the rank of the input to the Shape node is
  // known.
  if (!shape_c->RankKnown(shape_c->input(0))) return OkStatus();

  // Only attempt to evaluate if the begin/end/strides values of the
  // StridedSlice node are all scalars.
  for (int i = 1; i <= 3; ++i) {
    ShapeHandle input_shape = stride_c->input(i);
    if (stride_c->Value(stride_c->Dim(input_shape, 0)) != 1) {
      return OkStatus();
    }
  }

  // Only attempt to evaluate cases with non-complex masks.
  int32 begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  TF_RETURN_IF_ERROR(stride_c->GetAttr("begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(stride_c->GetAttr("end_mask", &end_mask));
  TF_RETURN_IF_ERROR(stride_c->GetAttr("ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(stride_c->GetAttr("new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(stride_c->GetAttr("shrink_axis_mask", &shrink_axis_mask));

  // Case where the user has sliced a single element of a collection, e.g.
  // collection[i].
  bool accesses_single_element = begin_mask == 0 && end_mask == 0 &&
                                 ellipsis_mask == 0 && new_axis_mask == 0 &&
                                 shrink_axis_mask == 1;

  if (!accesses_single_element) return OkStatus();

  // Calculate the output tensor from the Shape node.
  Tensor shape_output;
  TF_RETURN_IF_ERROR(TryToInferTensorOutputFromShapeNode(
      *shape_node, shape_c, &shape_output, success, /*allow_partial=*/true));
  if (!*success) return OkStatus();

  // Discard the output tensor computed above if the StridedSlice points to an
  // unknown dimension.
  int64 begin_value = 0;
  bool evaluated = false;
  *success = false;
  TF_RETURN_IF_ERROR(EvaluateConstantIntFromScalarEdge(
      *stride_node, 1, &begin_value, &evaluated));

  if (evaluated && node.output_type(0) == shape_output.dtype()) {
    begin_value = begin_value < 0
                      ? begin_value + shape_c->Rank(shape_c->input(0))
                      : begin_value;
    Tensor t(node.output_type(0), TensorShape({}));
    if (shape_output.dtype() == DT_INT32 &&
        shape_output.flat<int>()(begin_value) != -1) {
      t.flat<int32>()(0) = shape_output.flat<int>()(begin_value);
      *output = t;
      *success = true;
    } else if (shape_output.dtype() == DT_INT64 &&
               shape_output.flat<int64_t>()(begin_value) != -1) {
      t.flat<int64_t>()(0) = shape_output.flat<int64_t>()(begin_value);
      *output = t;
      *success = true;
    }
  }

  return OkStatus();
}

// Tries to infer the tensor output based on the input shapes of the node. In
// some cases, the shapes of the inputs are sufficient for inferring the
// contents of the output tensor. For example, a Shape op with fully defined
// input shapes can have its output tensor inferred.
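// Examples of what can be inferred from input shapes alone (illustrative):
//   Shape: input shape (2, 3)    -> output tensor [2, 3]
//   Rank:  input shape (2, ?, 4) -> output tensor 3
//   Size:  input shape (2, 3, 4) -> output tensor 24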
Status TryToInferTensorOutputFromInputShapes(const Edge& edge,
                                             const ShapeRefiner& refiner,
                                             Tensor* output, bool* success) {
  *success = false;
  const Node* node = edge.src();
  InferenceContext* c = refiner.GetContext(node);
  if (c == nullptr) {
    // An input without context is a soft failure; we sometimes need to break
    // control flow loops by running shape inference on a node without first
    // adding its input.
    return OkStatus();
  }

  if (node->type_string() == "StridedSlice") {
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromStridedSliceNode(
        *node, refiner, output, success));
  } else if (node->type_string() == "Shape") {
    // If input shapes to the shape op are fully defined,
    // we can infer the shape op's output tensor.
    TF_RETURN_IF_ERROR(
        TryToInferTensorOutputFromShapeNode(*node, c, output, success));
  } else if (node->type_string() == "Rank") {
    bool rank_known = c->RankKnown(c->input(0));
    if (rank_known) {
      int32 input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      t.flat<int32>()(0) = input_rank;
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Size") {
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int32 rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      int64 size = 1;
      for (int i = 0; i < rank; i++) {
        size *= c->Value(c->Dim(c->input(0), i));
      }
      if (node->output_type(0) == DT_INT32) {
        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
          return errors::InvalidArgument(
              "Size has output type int32, but size exceeds maximum int32 "
              "value");
        }
        t.flat<int32>()(0) = static_cast<int32>(size);
      } else if (node->output_type(0) == DT_INT64) {
        t.flat<int64_t>()(0) = size;
      } else {
        return errors::FailedPrecondition(
            "Size has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  }
  return OkStatus();
}

// Returns true if 'node' has a registered CPU kernel.
bool HasCpuKernel(const Node& node) {
  return FindKernelDef(DeviceType(DEVICE_CPU), node.def(), /*def=*/nullptr,
                       /*kernel_class_name=*/nullptr)
      .ok();
}

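// Returns in '*index' the value of the "index" attr of the Arg node 'node',
// validated to lie in [0, num_function_inputs).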
Status GetArgNodeIndex(const Node* node, int num_function_inputs, int* index) {
  DCHECK(node->IsArg());
  TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", index));
  if (*index < 0 || num_function_inputs <= *index) {
    return errors::Internal(
        "Function instantiation included invalid input index: ", *index,
        " not in [0, ", num_function_inputs, ").");
  }
  return OkStatus();
}

// Extracts the subgraph ending at 'target_node' that is statically computable
// and inserts it into 'out_graph'. If the subgraph is statically computable,
// 'is_constant_graph' will be set to true.
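// Example (illustrative): for target_node = Add(Const(2), Const(3)), both
// Const nodes and the Add node are copied into 'out_graph' and
// 'is_constant_graph' is set to true; if one input were a Placeholder instead,
// 'is_constant_graph' would be set to false.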
Status ExtractConstantSubgraph(
    const Node& target_node, const ShapeRefiner& refiner,
    const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
    bool* is_constant_graph,
    std::vector<std::pair<string, Tensor>>* const_inputs,
    InferenceContext* outer_context) {
  *is_constant_graph = false;
  std::unordered_set<string> const_inputs_added;
  if (target_node.op_def().is_stateful()) {
    return OkStatus();
  }

  if (IsMerge(&target_node)) {
    return OkStatus();
  }

  if (target_node.type_string() == "PlaceholderWithDefault") {
    return OkStatus();
  }

  // Since constant-folding runs on the CPU, do not attempt to constant-fold
  // operators that have no CPU kernel.
  if (!HasCpuKernel(target_node)) {
    return OkStatus();
  }

  // TODO(skyewm): should more of the filtering applied to input nodes below be
  // applied to target_node here?

  // Identify the possibly constant subgraph by recursively iterating backwards
  // through the inputs to 'target_node' until we either 1) find an already
  // existing input to our subgraph 'const_inputs', 2) discover our graph is
  // not constant, or 3) hit a root node.

  struct NodeAndRecursed {
    Node* new_node = nullptr;
    bool recursed = false;
  };

  std::map<const Node*, NodeAndRecursed> old_to_new_and_recursed;
  Node* target_node_copy = out_graph->CopyNode(&target_node);
  old_to_new_and_recursed[&target_node].new_node = target_node_copy;
  old_to_new_and_recursed[&target_node].recursed = true;

  // Add the target node's inputs to seed the recursion.
  std::deque<const Edge*> edges_to_visit;
  for (const Edge* e : target_node.in_edges()) {
    // TODO(skyewm): control edges will be meaningful if/when we handle control
    // flow (e.g. constants in cond branches are triggered via control edges).
    if (e->IsControlEdge()) continue;
    edges_to_visit.push_back(e);
  }

  *is_constant_graph = true;

  // Iterate over the set of edges to visit (backwards).
  while (!edges_to_visit.empty()) {
    const Edge* current_edge = edges_to_visit.front();
    edges_to_visit.pop_front();
    Node* current_node = current_edge->src();

    // If the node is stateful, assume the graph is not constant unless it is
    // an Arg node, which is handled later on.
    if (!current_node->IsArg() && current_node->op_def().is_stateful()) {
      *is_constant_graph = false;
      return OkStatus();
    }

    // During construction or import from GraphConstructor, back edges may not
    // be filled in. In addition, control flow constructs may depend on control
    // edges which aren't handled by this method. Don't constant fold through
    // merges at all for now.
    if (IsMerge(current_node)) {
      *is_constant_graph = false;
      return OkStatus();
    }

    // Don't constant fold enter/exit currently either, as it's easy to end
    // up with a partial frame.
    if (IsEnter(current_node) || IsExit(current_node)) {
      *is_constant_graph = false;
      return OkStatus();
    }

    // Placeholders should never be constant folded because their outputs are
    // fed by the user. Note that "Placeholder" nodes have no inputs so are
    // handled below.
    if (current_node->type_string() == "PlaceholderWithDefault") {
      *is_constant_graph = false;
      return OkStatus();
    }

    if (!HasCpuKernel(*current_node)) {
      *is_constant_graph = false;
      return OkStatus();
    }

    // If there is nothing more to recurse down, see if
    // the generator node is a constant or an Arg node whose value is available
    // in the `outer_context`.
    if (current_node->num_inputs() == 0) {
      if (outer_context && current_node->IsArg()) {
        const string& tensor_name =
            strings::StrCat(current_node->name(), ":", 0);
        // If we do not already have a constant Tensor for this Arg, try to
        // fetch it from the outer context.
        if (const_inputs_added.count(tensor_name) == 0) {
          int index;
          TF_RETURN_IF_ERROR(GetArgNodeIndex(
              current_node, outer_context->num_inputs(), &index));
          const Tensor* const_tensor = outer_context->input_tensor(index);
          if (const_tensor) {
            const_inputs->emplace_back(tensor_name, *const_tensor);
            const_inputs_added.insert(tensor_name);
          } else {
            // Request a constant value for this Arg. If that is statically
            // computable, the shape refiner will re-run the shape inference
            // for this function with this tensor's value.
            outer_context->request_input_tensor(index);
            *is_constant_graph = false;
            return OkStatus();
          }
        }
      } else if (!current_node->IsConstant()) {
        // The generator node is not a constant, so the subgraph is not
        // constant.
        *is_constant_graph = false;
        return OkStatus();
      }
    }

    // Either the node is a constant, or the node is a potential
    // intermediate node on the path from a constant.
    //
    // Add a copy of its node and a new edge to the new subgraph.

    // Get or create the version of 'current_node' in the new graph.
    Node* current_node_copy;
    // This gets or creates the NodeAndRecursed entry for current_node.
    NodeAndRecursed* node_and_recursed = &old_to_new_and_recursed[current_node];
    if (node_and_recursed->new_node == nullptr) {
      // First time processing this node.
      current_node_copy = out_graph->CopyNode(current_node);
      // Track the mapping from the original node to the new one.
      node_and_recursed->new_node = current_node_copy;
    } else {
      current_node_copy = node_and_recursed->new_node;
    }

    // Add the edge to the destination node.
    {
      auto it = old_to_new_and_recursed.find(current_edge->dst());
      if (it == old_to_new_and_recursed.end()) {
        return errors::Internal(
            "Could not find mapping from old to new copy of destination node: ",
            current_edge->dst()->name());
      }
      Node* dst_copy = it->second.new_node;

      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
                         dst_copy, current_edge->dst_input());
    }

    const string& output_tensor_name =
        strings::StrCat(current_node->name(), ":", current_edge->src_output());

    // Some tensor values can be inferred. For example, a shape op
    // with input shapes fully defined can have its output tensor inferred.
    Tensor tensor_inferred;
    bool successfully_inferred_tensor = false;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
        *current_edge, refiner, &tensor_inferred,
        &successfully_inferred_tensor));
    if (successfully_inferred_tensor) {
      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If we have a copy of the input tensor materialized already,
    // then add it to the list of inputs to feed and do not recurse further.
    if (cached_values != nullptr) {
      auto it = cached_values->find(output_tensor_name);
      if (it != cached_values->end() &&
          const_inputs_added.count(output_tensor_name) == 0) {
        const_inputs->emplace_back(output_tensor_name, it->second);
        const_inputs_added.insert(output_tensor_name);
        continue;
      }
    }

    // If this node's inputs have not been processed already, do so now.
    if (!node_and_recursed->recursed) {
      node_and_recursed->recursed = true;
      for (const Edge* e : current_node->in_edges()) {
        if (e->IsControlEdge()) continue;
        edges_to_visit.push_back(e);
      }
    }
  }
  return OkStatus();
}

}  // namespace

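// Illustrative usage sketch (for exposition only; 'node', 'refiner',
// 'op_registry', 'graph_def_version', and 'outer_ctx' are hypothetical locals,
// and the argument order simply follows the definition below):
//   bool evaluated = false;
//   Tensor value;
//   TF_RETURN_IF_ERROR(EvaluateConstantTensor(
//       OutputTensor(node, /*index=*/0), refiner, *op_registry,
//       graph_def_version, &evaluated, &value, /*graph_runner=*/nullptr,
//       /*cached_values=*/nullptr, /*max_cached_value_size=*/1024,
//       /*disable_constant_propagation=*/false, outer_ctx));
//   if (evaluated) {
//     // 'value' holds the constant result of node:0, if it was computable.
//   }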
Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
                              const OpRegistryInterface& ops,
                              int32 graph_def_version, bool* evaluated,
                              Tensor* result, GraphRunner* graph_runner,
                              std::unordered_map<string, Tensor>* cached_values,
                              int64 max_cached_value_size,
                              bool disable_constant_propagation,
                              InferenceContext* outer_context) {
  *evaluated = false;
  const Node* src = tensor.node;

  // Simple case: the source node is a constant
  TF_RETURN_IF_ERROR(EvaluateConstantNode(*src, result, evaluated));
  if (*evaluated) return OkStatus();

  // Shape slice: the source node is slicing a single value out of a shape.
  // This is needed to handle the case where the StridedSlice is the only
  // subgraph, as in a simple expression such as tf.shape([-1, 10])[-1]
  // (the ExtractConstantSubgraph call below only looks at the input srcs of
  // the various edges; there is never a chance to evaluate the StridedSlice
  // node, as it is never an input src).
  if (src->type_string() == "StridedSlice") {
    Tensor slice_output;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromStridedSliceNode(
        *src, refiner, &slice_output, evaluated));
    if (*evaluated) {
      *result = slice_output;
      return OkStatus();
    }
  }

  // If the source node is an Arg, return its value if it is available in the
  // outer context.
  if (src->IsArg() && outer_context) {
    int index;
    TF_RETURN_IF_ERROR(
        GetArgNodeIndex(src, outer_context->num_inputs(), &index));
    const Tensor* const_tensor = outer_context->input_tensor(index);
    if (const_tensor) {
      *evaluated = true;
      *result = *const_tensor;
    } else {
      outer_context->request_input_tensor(index);
    }
    return OkStatus();
  }

  if (disable_constant_propagation) {
    return OkStatus();
  }

  bool is_constant_graph = false;
  Graph subgraph(&ops);
  auto versions = subgraph.versions();
  versions.set_producer(graph_def_version);
  subgraph.set_versions(versions);

  std::vector<std::pair<string, Tensor>> const_inputs;
  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
                                             &subgraph, &is_constant_graph,
                                             &const_inputs, outer_context));
  if (!is_constant_graph) {
    return OkStatus();
  }
  const string output_tensor_name =
      strings::StrCat(src->name(), ":", tensor.index);
  std::vector<Tensor> outputs;

  std::unique_ptr<GraphRunner> graph_runner_storage;
  if (graph_runner == nullptr) {
    // TODO(skyewm): Convert to std::make_unique when available.
    graph_runner_storage.reset(new GraphRunner(Env::Default()));
    graph_runner = graph_runner_storage.get();
  }

  // NOTE: we should pass in a function library runtime if we want
  // to support constant-expression evaluation on functions.
  Status s = graph_runner->Run(&subgraph, nullptr /* function_library */,
                               const_inputs, {output_tensor_name}, &outputs);

  // If some of the kernels in the constant graph are not registered
  // in the process, GraphRunner::Run may fail, in which case
  // we cannot propagate constants, so this is best-effort.
  if (s.ok()) {
    *result = outputs[0];
    *evaluated = true;

    // We memoize (small) constants evaluated so far, so
    // ExtractConstantSubgraph can avoid extracting the full
    // subgraph.  As we build up large graphs, this avoids
    // repeated computation of the early parts of a constant
    // graph.
    if (cached_values != nullptr &&
        outputs[0].TotalBytes() <= max_cached_value_size) {
      (*cached_values)[output_tensor_name] = outputs[0];
    }
  }
  return OkStatus();
}

}  // namespace tensorflow