xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
17 
18 #include <fstream>
19 #include <memory>
20 #include <unordered_map>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/op.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/grappler/clusters/cluster.h"
32 #include "tensorflow/core/grappler/costs/virtual_placer.h"
33 #include "tensorflow/core/grappler/devices.h"
34 #include "tensorflow/core/grappler/grappler_item.h"
35 #include "tensorflow/core/grappler/mutable_graph_view.h"
36 #include "tensorflow/core/grappler/op_types.h"
37 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h"
38 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
39 #include "tensorflow/core/grappler/utils.h"
40 #include "tensorflow/core/lib/io/path.h"
41 #include "tensorflow/core/lib/strings/numbers.h"
42 #include "tensorflow/core/lib/strings/str_util.h"
43 #include "tensorflow/core/lib/strings/strcat.h"
44 #include "tensorflow/core/platform/logging.h"
45 #include "tensorflow/core/platform/status.h"
46 #include "tensorflow/core/util/env_var.h"
47 
48 namespace tensorflow {
49 namespace grappler {
50 namespace {
51 
52 bool ShouldSimulateGpu() {
53   bool is_enabled = [] {
54     bool ret = false;
55     string var;
56     TF_CHECK_OK(ReadStringFromEnvVar(
57         "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU", "", &var));
58     TF_CHECK_OK(
59         ReadBoolFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU",
60                            /*default_val=*/false, &ret));
61     return ret;
62   }();
63   return is_enabled;
64 }
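// Illustrative note: setting the environment variable
//   TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU=true
// makes ShouldSimulateGpu() return true, so the pass behaves as if a suitable
// GPU were present even on a CPU-only cluster (see GetDevices() below).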
65 
66 #if GOOGLE_CUDA
67 const std::pair<int, int> kMinGPUArch = {7, 0};
68 #else
69 const std::pair<int, int> kMinGPUArch = {0, 0};
70 #endif
71 
72 const char kSuffix[] = "AutoMixedPrecision";
73 const char kCastToFp16[] = "CastToFp16";
74 const char kCastToBf16[] = "CastToBf16";
75 const char kCastToFp32[] = "CastToFp32";
76 
77 #if GOOGLE_CUDA
78 // Returns the GPU architecture (compute capability) as a (major, minor) pair.
79 std::pair<int, int> GetDeviceGPUArch(
80     const DeviceProperties& device_properties) {
81   if (device_properties.type() != "GPU") return {0, 0};
82   string arch_str = device_properties.environment().at("architecture");
83   std::vector<string> split_arch_str = str_util::Split(arch_str, '.');
84   if (split_arch_str.empty()) {
85     return {0, 0};
86   }
87 
88   int major, minor;
89   if (!strings::safe_strto32(split_arch_str[0], &major)) {
90     return {0, 0};
91   }
92 
93   if (split_arch_str.size() > 1) {
94     if (strings::safe_strto32(split_arch_str[1], &minor)) {
95       return {major, minor};
96     } else {
97       return {0, 0};
98     }
99   } else {
100     return {major, 0};
101   }
102 }
103 #endif
104 
105 // Returns true if the device has fast FP16 compute support, false otherwise.
106 // For CUDA, we return true if the GPU arch (compute capability) is >=
107 // kMinGPUArch. For AMD, we return true if the detected GPU's gfx arch string
108 // is in the list of architectures with FP16-capable compute.
109 bool HasFastFP16Support(const DeviceProperties& props) {
110 #if GOOGLE_CUDA
111   return GetDeviceGPUArch(props) >= kMinGPUArch;
112 #elif TENSORFLOW_USE_ROCM
113   absl::flat_hash_set<std::string> FP16SupportedDevices = {{"gfx906"},
114                                                            {"gfx908"}};
115   std::string gcnArchName = props.environment().at("architecture");
116   std::vector<std::string> gpu_arch = absl::StrSplit(gcnArchName, ":");
117   return !gpu_arch.empty() && FP16SupportedDevices.contains(gpu_arch[0]);
118 #endif
119   return ShouldSimulateGpu();
120 }
121 
122 // Instances of this class represent unique type attribute identifiers within a
123 // node. It handles regular type attributes, list type attributes (where
124 // type_index is set to the index in the type list), and fixed types.
125 struct TypeAttrId {
126   static constexpr int kSingleType = -1;
127 
128   explicit TypeAttrId(const string& _attr_name, int _type_index = kSingleType)
129       : attr_name(_attr_name),
130         type_index(_type_index),
131         fixed_type(DT_INVALID) {}
132 
133   explicit TypeAttrId(DataType _fixed_type)
134       : attr_name(), type_index(kSingleType), fixed_type(_fixed_type) {}
135 
136   bool operator==(const TypeAttrId& other) const {
137     return attr_name == other.attr_name && type_index == other.type_index &&
138            fixed_type == other.fixed_type;
139   }
140 
141   bool operator<(const TypeAttrId& other) const {
142     return std::make_tuple(attr_name, type_index, fixed_type) <
143            std::make_tuple(other.attr_name, other.type_index, other.fixed_type);
144   }
145 
146   template <typename H>
147   friend H AbslHashValue(H h, const TypeAttrId& ta) {
148     return H::combine(std::move(h), ta.attr_name, ta.type_index, ta.fixed_type);
149   }
150 
151   string DebugString() const {
152     if (!attr_name.empty()) {
153       if (type_index == kSingleType) {
154         return attr_name;
155       } else {
156         return strings::StrCat(attr_name, "[", type_index, "]");
157       }
158     } else {
159       return tensorflow::DataTypeString(fixed_type);
160     }
161   }
162 
163   string attr_name;
164   // If attr_name is a list(type), this is the index into the list. Otherwise
165   // this is kSingleType.
166   int type_index;
167   DataType fixed_type;
168 };
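// Illustrative examples of the three identifier kinds:
//   TypeAttrId("T").DebugString()        -> "T"        (single type attribute)
//   TypeAttrId("Tout", 1).DebugString()  -> "Tout[1]"  (list(type) attribute, index 1)
//   TypeAttrId(DT_FLOAT).DebugString()   -> "float"    (fixed type, no attribute)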
169 
170 // Returns the data type of the given type attribute, or DT_INVALID if the type
171 // attribute is invalid.
172 DataType GetDataType(const NodeDef& node, const TypeAttrId& type_attr) {
173   if (type_attr.attr_name.empty()) {
174     return type_attr.fixed_type;
175   }
176   if (!node.attr().count(type_attr.attr_name)) {
177     return DT_INVALID;
178   }
179   const AttrValue& attr_value = node.attr().at(type_attr.attr_name);
180   if (type_attr.type_index == TypeAttrId::kSingleType) {
181     return attr_value.type();
182   } else {
183     if (type_attr.type_index < 0 ||
184         type_attr.type_index >= attr_value.list().type_size()) {
185       return DT_INVALID;
186     }
187     return attr_value.list().type(type_attr.type_index);
188   }
189 }
190 
191 // Sets the data type of the given type attribute. Returns false if the type
192 // attribute is invalid, otherwise true.
193 bool SetDataType(NodeDef* node, const TypeAttrId& type_attr, DataType type) {
194   if (type_attr.attr_name.empty() || !node->attr().count(type_attr.attr_name)) {
195     return false;
196   }
197   AttrValue& attr_value = node->mutable_attr()->at(type_attr.attr_name);
198   if (type_attr.type_index == TypeAttrId::kSingleType) {
199     attr_value.set_type(type);
200   } else {
201     if (type_attr.type_index < 0 ||
202         type_attr.type_index >= attr_value.list().type_size()) {
203       return false;
204     }
205     attr_value.mutable_list()->set_type(type_attr.type_index, type);
206   }
207   return true;
208 }
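// Minimal sketch (illustrative; assumes `cast` is a NodeDef for a Cast op):
//   TypeAttrId dst_attr("DstT");
//   if (GetDataType(cast, dst_attr) == DT_FLOAT) {
//     SetDataType(&cast, dst_attr, DT_HALF);  // rewrites the "DstT" attr value
//   }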
209 
210 std::vector<std::pair<int, int>> ArgDefIndexes(const NodeDef& node, int arg_idx,
211                                                const OpDef::ArgDef& arg_def) {
212   std::vector<std::pair<int, int>> argdef_inds;
213   if (!arg_def.type_list_attr().empty()) {
214     int num_types = node.attr().at(arg_def.type_list_attr()).list().type_size();
215     for (int type_idx = 0; type_idx < num_types; ++type_idx) {
216       argdef_inds.push_back({arg_idx, type_idx});
217     }
218   } else {
219     int num_repeat = 1;
220     if (node.attr().count(arg_def.number_attr())) {
221       num_repeat = node.attr().at(arg_def.number_attr()).i();
222     }
223     argdef_inds.insert(argdef_inds.end(), num_repeat, {arg_idx, -1});
224   }
225   return argdef_inds;
226 }
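// Example (illustrative): for a ConcatV2 node with attr N=3, the "values" input
// arg (declared via number_attr "N") yields three {arg_idx, -1} entries, whereas
// an arg declared via type_list_attr yields one {arg_idx, type_idx} entry per
// type in the list.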
227 
228 // Returns a pair (arg_index, type_index) for each input to the node, where
229 // arg_index is the index of the input_arg in op_def and type_index is the index
230 // of the type in type_list_attr (only defined for list arguments).
231 std::vector<std::pair<int, int>> InputPortArgDefIndexes(const NodeDef& node,
232                                                         const OpDef& op_def) {
233   std::vector<std::pair<int, int>> argdef_inds;
234   argdef_inds.reserve(op_def.input_arg_size());  // Final size may differ.
235   for (int arg_idx = 0; arg_idx < op_def.input_arg_size(); ++arg_idx) {
236     const OpDef::ArgDef& arg_def = op_def.input_arg(arg_idx);
237     auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
238     argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
239                        arg_results.end());
240   }
241   return argdef_inds;
242 }
243 
244 // Returns a pair (arg_index, type_index) for each output to the node, where
245 // arg_index is the index of the output_arg in op_def and type_index is the
246 // index of the type in type_list_attr (only defined for list arguments).
247 std::vector<std::pair<int, int>> OutputPortArgDefIndexes(const NodeDef& node,
248                                                          const OpDef& op_def) {
249   std::vector<std::pair<int, int>> argdef_inds;
250   argdef_inds.reserve(op_def.output_arg_size());  // Final size may differ.
251   for (int arg_idx = 0; arg_idx < op_def.output_arg_size(); ++arg_idx) {
252     const OpDef::ArgDef& arg_def = op_def.output_arg(arg_idx);
253     auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
254     argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
255                        arg_results.end());
256   }
257   return argdef_inds;
258 }
259 
260 TypeAttrId GetTypeAttrId(const OpDef::ArgDef& arg_def, int arg_type_index) {
261   if (!arg_def.type_list_attr().empty()) {
262     return TypeAttrId(arg_def.type_list_attr(), arg_type_index);
263   } else if (!arg_def.type_attr().empty()) {
264     return TypeAttrId(arg_def.type_attr());
265   } else {
266     return TypeAttrId(arg_def.type());
267   }
268 }
269 
270 std::vector<int> NonControlInputs(const NodeDef& node) {
271   std::vector<int> pos;
272   for (int i = 0; i < node.input_size(); i++) {
273     if (!IsControlInput(node.input(i))) {
274       pos.push_back(i);
275     }
276   }
277   return pos;
278 }
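// Example (illustrative): for a node with inputs {"x", "^ctrl", "y:1"}, the
// control input "^ctrl" is skipped and NonControlInputs returns {0, 2}.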
279 
280 // A utility class to lookup node type attributes and type attribute <->
281 // input/output port mappings.
282 class NodeTypeAttrMap {
283  public:
284   NodeTypeAttrMap() {}
285 
286   explicit NodeTypeAttrMap(const GraphDef& graph) { TF_CHECK_OK(Init(graph)); }
287 
288   Status Init(const GraphDef& graph) {
289     if (graph_ != nullptr) {
290       return errors::InvalidArgument("NodeTypeAttrMap is already initialized.");
291     }
292     graph_ = &graph;
293     function_library_.reset(
294         new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
295     for (const NodeDef& node : graph.node()) {
296       TF_RETURN_IF_ERROR(AddNode(node));
297     }
298     return OkStatus();
299   }
300 
301   bool is_initialized() const { return graph_ != nullptr; }
302 
303   // Returns the set of all type attributes in the given node.
304   absl::flat_hash_set<TypeAttrId> GetTypeAttrs(const NodeDef& node) const {
305     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
306     absl::flat_hash_set<TypeAttrId> type_attrs;
307     const auto iter = type2io_.find(&node);
308     CHECK(iter != type2io_.end());  // Crash Ok
309     for (const auto& key_value : iter->second) {
310       type_attrs.insert(key_value.first);
311     }
312     return type_attrs;
313   }
314 
315   const absl::flat_hash_set<int>& GetInputPorts(
316       const NodeDef& node, const TypeAttrId& type_attr) const {
317     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
318     return type2io_.at(&node).at(type_attr).first;
319   }
320 
321   const absl::flat_hash_set<int>& GetOutputPorts(
322       const NodeDef& node, const TypeAttrId& type_attr) const {
323     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
324     return type2io_.at(&node).at(type_attr).second;
325   }
326 
327   TypeAttrId GetInputTypeAttr(const NodeDef& node, int port) const {
328     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
329     const auto iter = io2type_.find(&node);
330     DCHECK(iter != io2type_.end())
331         << "Node " << node.name() << " doesn't exist in a graph";
332     auto type_vec = io2type_.at(&node).first;
333     CHECK_GE(port, 0);                // Crash Ok
334     CHECK_LT(port, type_vec.size());  // Crash Ok
335     return type_vec[port];
336   }
337 
338   TypeAttrId GetOutputTypeAttr(const NodeDef& node, int port) const {
339     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
340     auto type_vec = io2type_.at(&node).second;
341     CHECK_GE(port, 0);                // Crash Ok
342     CHECK_LT(port, type_vec.size());  // Crash Ok
343     return type_vec[port];
344   }
345 
346  private:
347   Status AddNode(const NodeDef& node) {
348     const OpDef* op_def_ptr = nullptr;
349     TF_RETURN_IF_ERROR(function_library_->LookUpOpDef(node.op(), &op_def_ptr));
350     const OpDef& op_def = *op_def_ptr;
351     auto& type2io_entry = type2io_[&node];
352     auto& io2type_entry = io2type_[&node];
353     auto input_arg_inds = InputPortArgDefIndexes(node, op_def);
354     if (NonControlInputs(node).size() != input_arg_inds.size()) {
355       return errors::InvalidArgument(
356           "Expected ", node.op(), " node ", node.name(), " to have ",
357           input_arg_inds.size(), " non-control input(s), but got ",
358           node.input_size());
359     }
360     // Note that the mappings generated here include inputs/outputs with fixed
361     // types. This makes the mappings complete (all inputs and outputs are
362     // included), and allows the graph rewriter to propagate deny paint
363     // from/through ops with fixed types.
364     io2type_entry.first.reserve(input_arg_inds.size());
365     for (int i = 0; i < static_cast<int>(input_arg_inds.size()); ++i) {
366       const auto& arg_inds = input_arg_inds[i];
367       const OpDef::ArgDef& arg_def = op_def.input_arg(arg_inds.first);
368       TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
369       if (!type_attr.attr_name.empty() &&
370           !node.attr().count(type_attr.attr_name)) {
371         return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
372                                        " is not present in node ", node.name());
373       }
374       type2io_entry[type_attr].first.insert(i);
375       io2type_entry.first.push_back(type_attr);
376     }
377 
378     auto output_arg_inds = OutputPortArgDefIndexes(node, op_def);
379     io2type_entry.second.reserve(output_arg_inds.size());
380     for (int i = 0; i < static_cast<int>(output_arg_inds.size()); ++i) {
381       const auto& arg_inds = output_arg_inds[i];
382       const OpDef::ArgDef& arg_def = op_def.output_arg(arg_inds.first);
383       TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
384       if (!type_attr.attr_name.empty() &&
385           !node.attr().count(type_attr.attr_name)) {
386         return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
387                                        " is not present in node ", node.name());
388       }
389       type2io_entry[type_attr].second.insert(i);
390       io2type_entry.second.push_back(type_attr);
391     }
392 
393     // Also ensure that type attributes that aren't associated with any inputs
394     // or outputs (e.g., StackV2's elem_type) are added to the map.
395     for (const auto& attr : node.attr()) {
396       const string& attr_name = attr.first;
397       if (!attr_name.empty() && attr_name[0] == '_') continue;
398       const AttrValue& attr_value = attr.second;
399       const OpDef::AttrDef* attr_def = FindAttr(attr_name, op_def);
400       if (!attr_def) {
401         return errors::InvalidArgument("AttrDef not found for attribute ",
402                                        attr_name, " of node ", node.name());
403       }
404       if (attr_def->type() == "type") {
405         type2io_entry[TypeAttrId(attr_name)];
406       } else if (attr_def->type() == "list(type)") {
407         for (int i = 0; i < attr_value.list().type_size(); ++i) {
408           type2io_entry[TypeAttrId(attr_name, i)];
409         }
410       }
411     }
412     return OkStatus();
413   }
414 
415   // WARN: `graph_` must outlive this object (node pointers must remain valid).
416   const GraphDef* graph_ = nullptr;  // do not own
417   std::unique_ptr<FunctionLibraryDefinition> function_library_;
418 
419   typedef absl::flat_hash_set<int> IntSet;
420   // Maps a type attr id -> (input port set, output port set)
421   typedef absl::flat_hash_map<TypeAttrId, std::pair<IntSet, IntSet>> Type2IOMap;
422   // Maps a node -> type attr mapping
423   absl::flat_hash_map<const NodeDef*, Type2IOMap> type2io_;
424   // Maps a port -> type attr id
425   typedef std::vector<TypeAttrId> TypeAttrIdVec;
426   // Maps a node -> (input port mapping, output port mapping)
427   absl::flat_hash_map<const NodeDef*, std::pair<TypeAttrIdVec, TypeAttrIdVec>>
428       io2type_;
429 };
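// Minimal usage sketch (illustrative; assumes `graph` is a valid GraphDef):
//   NodeTypeAttrMap node_type_map(graph);
//   for (const NodeDef& node : graph.node()) {
//     for (const TypeAttrId& ta : node_type_map.GetTypeAttrs(node)) {
//       const auto& in_ports = node_type_map.GetInputPorts(node, ta);
//       const auto& out_ports = node_type_map.GetOutputPorts(node, ta);
//       // ... inspect which ports are governed by this type attribute ...
//     }
//   }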
430 
431 struct NodeTypeId {
432   NodeTypeId(const NodeDef* _node, const TypeAttrId& _type_attr)
433       : node(_node), type_attr(_type_attr) {}
434 
435   const NodeDef* node;
436   TypeAttrId type_attr;
437 
438   bool operator==(const NodeTypeId& other) const {
439     return node == other.node && type_attr == other.type_attr;
440   }
441 
442   template <typename H>
443   friend H AbslHashValue(H h, const NodeTypeId& nt) {
444     return H::combine(std::move(h), nt.node, nt.type_attr);
445   }
446 };
447 
448 struct NodeTypeIdEdge {
449   NodeTypeIdEdge(const NodeTypeId& _src, const NodeTypeId& _dst)
450       : src(_src), dst(_dst) {}
451   NodeTypeId src;
452   NodeTypeId dst;
453 };
454 
455 // TODO(benbarsdell): Investigate whether the existing GraphTopologyView can be
456 // used instead of this modified version.
457 // This is just like GraphTopologyView but with (NodeDef, TypeAttrId) pairs as
458 // the vertices instead of just NodeDef.
459 // For example, if node A has output A:0 with TypeAttrId 'T', and node B has
460 // input B:0 with TypeAttrId 'U', and input B:0 connects to output A:0, there
461 // will be an edge from (A, T) to (B, U).
462 class GraphTypeTopologyView {
463  public:
464   GraphTypeTopologyView() = default;
465   explicit GraphTypeTopologyView(bool skip_invalid_edges)
466       : skip_invalid_edges_(skip_invalid_edges) {}
467 
468   // Initialize graph topology view from the graph. It's possible to pass
469   // additional edges that do not exist in a graph, but must be respected when
470   // computing graph topology. Example: Tensorflow runtime allows concurrent
471   // execution of dequeue/enqueue ops from the same queue resource, but we might
472   // want to enforce ordering between them for the purpose of graph analysis.
473   Status InitializeFromGraph(const GraphDef& graph,
474                              const NodeTypeAttrMap& node_type_map);
475 
476   Status AddEphemeralEdges(absl::Span<const NodeTypeIdEdge> ephemeral_edges);
477 
478   bool is_initialized() const { return graph_ != nullptr; }
479   int num_nodes() const { return num_nodes_; }
480   const GraphDef* graph() const { return graph_; }
481 
482   // Returns true iff the node exists in the underlying graph.
483   bool HasNode(absl::string_view node_name, const TypeAttrId& type_attr) const;
484 
485   // Finds a node by name or returns `nullptr` if it's not in the graph.
486   const NodeTypeId* GetNode(absl::string_view node_name,
487                             const TypeAttrId& type_attr) const;
488   // Returns a node corresponding to the given node index.
489   const NodeTypeId* GetNode(int node_idx) const;
490 
491   // Returns a node index for the given node name, if the name exists in the
492   // underlying graph. Otherwise returns empty optional.
493   const absl::optional<int> GetNodeIndex(absl::string_view node_name,
494                                          const TypeAttrId& type_attr) const;
495   // Returns a node index for the given node, if the node belongs to the
496   // underlying graph. Otherwise returns empty optional.
497   const absl::optional<int> GetNodeIndex(const NodeTypeId& node) const;
498 
499   // Returns all the node indexes that are in the direct fanin of the given
500   // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
501   const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const;
502   // Returns all the node indexes that are in the direct fanout of the given
503   // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
504   const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const;
505 
506  private:
507   // The key type used to uniquely identify a type attribute on a node.
508   struct NodeTypeKey : public std::pair<absl::string_view, TypeAttrId> {
509     typedef std::pair<absl::string_view, TypeAttrId> Base;
510 
511     // Inherit the set of constructors.
512     using Base::pair;
513 
514     template <typename H>
515     friend H AbslHashValue(H h, const NodeTypeKey& nt) {
516       return H::combine(std::move(h), nt.first, nt.second);
517     }
518   };
519 
520   // If true, all invalid edges and inputs (src, dst or input node not found in
521   // the graph) will be skipped; otherwise initialization will fail with an error.
522   bool skip_invalid_edges_ = false;
523 
524   // WARN: `graph_` must outlive this object and graph nodes must not be
525   // destructed, because node names are captured with absl::string_view.
526   const GraphDef* graph_ = nullptr;  // do not own
527   int num_nodes_ = 0;
528   std::vector<NodeTypeId> node_type_attrs_;
529   absl::flat_hash_map<absl::string_view, int> node_name_to_index_;
530   absl::flat_hash_map<NodeTypeKey, int> node_type_name_to_index_;
531 
532   std::vector<absl::InlinedVector<int, 4>> fanins_;
533   std::vector<absl::InlinedVector<int, 2>> fanouts_;
534 
535   // We need a valid reference to return from GetFanin/GetFanout if the
536   // `node_idx` argument is outside of the [0, num_nodes_) range.
537   absl::InlinedVector<int, 4> empty_fanin_;
538   absl::InlinedVector<int, 2> empty_fanout_;
539 };
540 
541 template <typename T>
542 inline void SortAndRemoveDuplicates(T* v) {
543   std::sort(v->begin(), v->end());
544   v->erase(std::unique(v->begin(), v->end()), v->end());
545 }
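// Example (illustrative): applied to {3, 1, 3, 2}, SortAndRemoveDuplicates
// leaves {1, 2, 3}. It is used below to dedup the fanin/fanout adjacency lists.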
546 
547 Status GraphTypeTopologyView::InitializeFromGraph(
548     const GraphDef& graph, const NodeTypeAttrMap& node_type_map) {
549   if (graph_ != nullptr) {
550     return errors::InvalidArgument(
551         "GraphTypeTopologyView is already initialized.");
552   }
553 
554   graph_ = &graph;
555   int num_nodedefs = graph.node_size();
556   node_name_to_index_.rehash(num_nodedefs);
557 
558   // Build maps from name to index.
559   node_type_attrs_.reserve(num_nodedefs);         // Only approximate.
560   node_type_name_to_index_.rehash(num_nodedefs);  // Only approximate.
561   for (int node_idx = 0; node_idx < num_nodedefs; ++node_idx) {
562     const NodeDef& node = graph.node(node_idx);
563     node_name_to_index_.emplace(node.name(), node_idx);
564 
565     for (const TypeAttrId& type_attr : node_type_map.GetTypeAttrs(node)) {
566       int node_type_idx = node_type_attrs_.size();
567       node_type_name_to_index_.emplace(NodeTypeKey(node.name(), type_attr),
568                                        node_type_idx);
569       node_type_attrs_.emplace_back(&node, type_attr);
570     }
571   }
572   num_nodes_ = node_type_attrs_.size();
573   fanins_.resize(num_nodes_);
574   fanouts_.resize(num_nodes_);
575 
576   // Add graph edges to the adjacency lists.
577   for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
578     const NodeTypeId& node_type = node_type_attrs_.at(node_type_idx);
579     auto input_ports =
580         node_type_map.GetInputPorts(*node_type.node, node_type.type_attr);
581     fanins_[node_type_idx].reserve(input_ports.size());
582     for (int port : input_ports) {
583       const string& input = node_type.node->input(port);
584       TensorId tensor = ParseTensorName(input);
585       const auto it = node_name_to_index_.find(tensor.node());
586       const bool valid_input = it != node_name_to_index_.end();
587 
588       if (!valid_input) {
589         const string error_message = absl::StrCat(
590             "Non-existent input ", input, " in node ", node_type.node->name());
591         if (skip_invalid_edges_) {
592           VLOG(3) << "Skip error: " << error_message;
593         } else {
594           return errors::InvalidArgument(error_message);
595         }
596       }
597 
598       if (valid_input) {
599         const int input_idx = it->second;
600         const NodeDef& input_node = graph_->node(input_idx);
601         TypeAttrId input_type_attr =
602             node_type_map.GetOutputTypeAttr(input_node, tensor.index());
603         const auto it2 = node_type_name_to_index_.find(
604             NodeTypeKey(input_node.name(), input_type_attr));
605         if (it2 == node_type_name_to_index_.end()) {
606           if (!skip_invalid_edges_) {
607             return errors::InvalidArgument("Did not find type attr ",
608                                            input_type_attr.DebugString(),
609                                            " in node ", input_node.name());
610           }
611           continue;
612         }
613         int input_node_type_idx = it2->second;
614         fanins_[node_type_idx].push_back(input_node_type_idx);
615         fanouts_[input_node_type_idx].push_back(node_type_idx);
616       }
617     }
618 
619     // Dedup the input list while it's still hot in cache.
620     SortAndRemoveDuplicates(&fanins_[node_type_idx]);
621   }
622 
623   // Dedup outputs for all the graph nodes.
624   for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
625     SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
626   }
627 
628   return OkStatus();
629 }
630 
631 Status GraphTypeTopologyView::AddEphemeralEdges(
632     absl::Span<const NodeTypeIdEdge> ephemeral_edges) {
633   // Add ephemeral edges to the adjacency lists.
634   for (const NodeTypeIdEdge& edge : ephemeral_edges) {
635     const auto src = node_name_to_index_.find(edge.src.node->name());
636     const bool valid_src = src != node_name_to_index_.end();
637 
638     if (!valid_src) {
639       const string error_message =
640           absl::StrCat("Non-existent src node: ", edge.src.node->name());
641       if (skip_invalid_edges_) {
642         VLOG(0) << "Skip error: " << error_message;
643       } else {
644         return errors::InvalidArgument(error_message);
645       }
646     }
647 
648     const auto dst = node_name_to_index_.find(edge.dst.node->name());
649     const bool valid_dst = dst != node_name_to_index_.end();
650 
651     if (!valid_dst) {
652       const string error_message =
653           absl::StrCat("Non-existent dst node: ", edge.dst.node->name());
654       if (skip_invalid_edges_) {
655         VLOG(0) << "Skip error: " << error_message;
656       } else {
657         return errors::InvalidArgument(error_message);
658       }
659     }
660 
661     if (valid_dst && valid_src) {
662       // TODO(benbarsdell): Check for failure.
663       int src_node_type_idx = node_type_name_to_index_.at(
664           NodeTypeKey(edge.src.node->name(), edge.src.type_attr));
665       int dst_node_type_idx = node_type_name_to_index_.at(
666           NodeTypeKey(edge.dst.node->name(), edge.dst.type_attr));
667       fanins_[dst_node_type_idx].push_back(src_node_type_idx);
668       fanouts_[src_node_type_idx].push_back(dst_node_type_idx);
669     }
670   }
671 
672   // Dedup inputs and outputs for all the graph nodes.
673   for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
674     SortAndRemoveDuplicates(&fanins_[node_type_idx]);
675     SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
676   }
677 
678   return OkStatus();
679 }
680 
681 bool GraphTypeTopologyView::HasNode(absl::string_view node_name,
682                                     const TypeAttrId& type_attr) const {
683   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
684   NodeTypeKey key(node_name, type_attr);
685   const auto it = node_type_name_to_index_.find(key);
686   return it != node_type_name_to_index_.end();
687 }
688 
689 const NodeTypeId* GraphTypeTopologyView::GetNode(
690     absl::string_view node_name, const TypeAttrId& type_attr) const {
691   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
692   NodeTypeKey key(node_name, type_attr);
693   const auto it = node_type_name_to_index_.find(key);
694   return it == node_type_name_to_index_.end()
695              ? nullptr
696              : &node_type_attrs_.at(it->second);
697 }
698 
699 const NodeTypeId* GraphTypeTopologyView::GetNode(int node_idx) const {
700   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
701   DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range";
702   return &node_type_attrs_.at(node_idx);
703 }
704 
705 const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
706     absl::string_view node_name, const TypeAttrId& type_attr) const {
707   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
708   NodeTypeKey key(node_name, type_attr);
709   const auto it = node_type_name_to_index_.find(key);
710   DCHECK(it != node_type_name_to_index_.end())
711       << "Node doesn't exist in a graph";
712   return it == node_type_name_to_index_.end() ? absl::nullopt
713                                               : absl::make_optional(it->second);
714 }
715 
716 const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
717     const NodeTypeId& node) const {
718   return GetNodeIndex(node.node->name(), node.type_attr);
719 }
720 
721 const absl::InlinedVector<int, 4>& GraphTypeTopologyView::GetFanin(
722     int node_idx) const {
723   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
724   const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
725   DCHECK(is_valid_node_idx) << "node_idx is out of range";
726   return is_valid_node_idx ? fanins_[node_idx] : empty_fanin_;
727 }
728 
729 const absl::InlinedVector<int, 2>& GraphTypeTopologyView::GetFanout(
730     int node_idx) const {
731   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
732   const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
733   DCHECK(is_valid_node_idx) << "node_idx is out of range";
734   return is_valid_node_idx ? fanouts_[node_idx] : empty_fanout_;
735 }
736 
737 enum class TypeTraversalDirection {
738   kFollowInputs,
739   kFollowOutputs,
740   kFollowInputsAndOutputs,
741 };
742 
743 // Encapsulate DFS callbacks that will be called during the graph traversal.
744 //
745 // If non-empty, the `pre_order` and `post_order` functors will be called on
746 // each reachable node (including the `from` nodes) in pre and post order. If
747 // loops are found, the `on_back_edge` functor will be called on the
748 // corresponding back edges. Moreover, the pre and post order will assume that
749 // these back edges will be cut.
750 struct DfsTypeCallbacks {
751   DfsTypeCallbacks() = default;
752   DfsTypeCallbacks(std::function<void(int)> pre, std::function<void(int)> post,
753                    std::function<void(int, int)> back_edge)
754       : pre_order(std::move(pre)),
755         post_order(std::move(post)),
756         on_back_edge(std::move(back_edge)) {}
757 
758   static DfsTypeCallbacks PreOrder(std::function<void(int)> pre) {
759     return DfsTypeCallbacks(std::move(pre), nullptr, nullptr);
760   }
761 
762   static DfsTypeCallbacks PostOrder(std::function<void(int)> post) {
763     return DfsTypeCallbacks(nullptr, std::move(post), nullptr);
764   }
765 
766   std::function<void(int)> pre_order;
767   std::function<void(int)> post_order;
768   std::function<void(int, int)> on_back_edge;
769 };
770 
771 // Encapsulate DFS predicates for traversing the graph.
772 //
773 // The `enter` predicate decides if traversal should enter the node, and the
774 // `advance` predicate decides if the traversal should follow inputs/outputs
775 // from the node.
776 //
777 // If predicates are empty (default initialized), it's assumed that we can enter
778 // into any node and advance from any node respectively.
779 struct DfsTypePredicates {
780   DfsTypePredicates() = default;
781   DfsTypePredicates(std::function<bool(int)> enter,
782                     std::function<bool(int)> advance)
783       : enter(std::move(enter)), advance(std::move(advance)) {}
784 
785   static DfsTypePredicates Enter(std::function<bool(int)> enter) {
786     return DfsTypePredicates(std::move(enter), nullptr);
787   }
788 
789   static DfsTypePredicates Advance(std::function<bool(int)> advance) {
790     return DfsTypePredicates(nullptr, std::move(advance));
791   }
792 
793   std::function<bool(int)> enter;
794   std::function<bool(int)> advance;
795 };
796 
797 struct DfsStackElem {
798   DfsStackElem(int node, bool children_visited, int src)
799       : node(node), children_visited(children_visited), src(src) {}
800   explicit DfsStackElem(int node) : DfsStackElem(node, false, -1) {}
801 
802   // Index of the node in the graph ∊ [0, num_nodes).
803   int node;
804   // `True` if visited all the input/output nodes (pushed all input/output nodes
805   // to the stack).
806   bool children_visited;
807   // Index of the node in the graph, from which we entered the `node`.
808   int src;
809 };
810 
811 enum class NodeState { kNotVisited, kVisiting, kDone };
812 
813 void DfsTypeTraversal(const GraphTypeTopologyView& graph_type_view,
814                       const absl::Span<const NodeTypeId* const> from,
815                       const TypeTraversalDirection direction,
816                       const DfsTypePredicates& predicates,
817                       const DfsTypeCallbacks& callbacks) {
818   std::vector<DfsStackElem> stack;
819   stack.reserve(from.size());
820 
821   for (const NodeTypeId* node : from) {
822     const absl::optional<int> node_idx = graph_type_view.GetNodeIndex(*node);
823     DCHECK(node_idx.has_value())
824         << "Illegal start node: " << node->node->name();
825     if (node_idx.has_value()) {
826       stack.emplace_back(node_idx.value());
827     }
828   }
829 
830   absl::flat_hash_map<int, NodeState> node_state;
831   while (!stack.empty()) {
832     DfsStackElem w = stack.back();
833     stack.pop_back();
834 
835     NodeState& state = node_state[w.node];
836     if (state == NodeState::kDone) continue;
837 
838     // Skip nodes that we should not enter.
839     if (predicates.enter && !predicates.enter(w.node)) {
840       state = NodeState::kDone;
841       continue;
842     }
843 
844     // We've processed all the children of this node.
845     if (w.children_visited) {
846       state = NodeState::kDone;
847       if (callbacks.post_order) {
848         callbacks.post_order(w.node);
849       }
850       continue;
851     }
852 
853     // Loop detected.
854     if (state == NodeState::kVisiting) {
855       if (callbacks.on_back_edge) {
856         callbacks.on_back_edge(w.src, w.node);
857       }
858       continue;
859     }
860 
861     state = NodeState::kVisiting;
862     if (callbacks.pre_order) {
863       callbacks.pre_order(w.node);
864     }
865 
866     // Enqueue the node again with the children_visited flag set to true.
867     stack.emplace_back(w.node, true, w.src);
868 
869     // Check if we can continue traversal from the current node.
870     if (predicates.advance && !predicates.advance(w.node)) {
871       continue;
872     }
873 
874     // Now enqueue the fanin/fanout nodes.
875     if (direction == TypeTraversalDirection::kFollowInputs ||
876         direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
877       for (const int fanin : graph_type_view.GetFanin(w.node)) {
878         stack.emplace_back(fanin, false, w.node);
879       }
880     }
881     if (direction == TypeTraversalDirection::kFollowOutputs ||
882         direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
883       for (const int fanout : graph_type_view.GetFanout(w.node)) {
884         stack.emplace_back(fanout, false, w.node);
885       }
886     }
887   }
888 }
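// Minimal usage sketch (illustrative; `roots`, `deny_set` and `order` are
// hypothetical locals): collect reachable (node, type_attr) indices in post
// order without advancing past nodes already painted deny.
//   std::vector<int> order;
//   DfsTypeTraversal(
//       graph_type_view, roots, TypeTraversalDirection::kFollowOutputs,
//       DfsTypePredicates::Advance([&](int idx) { return !deny_set.count(idx); }),
//       DfsTypeCallbacks::PostOrder([&](int idx) { order.push_back(idx); }));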
889 
890 DataTypeSet AllowedDataTypes(const OpDef::AttrDef& attr_def) {
891   const auto& allowed_types = attr_def.allowed_values().list().type();
892   if (allowed_types.empty()) {
893     return AllTypes();
894   }
895   uint32 dtype_mask = 0;
896   for (int dtype : allowed_types) {
897     dtype_mask |= 1u << dtype;
898   }
899   return DataTypeSet(dtype_mask);
900 }
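// Example (illustrative): for an attr whose allowed_values list is
// {DT_FLOAT, DT_HALF}, the returned set has exactly the bits
// (1u << DT_FLOAT) | (1u << DT_HALF) set; an empty allowed list means all types.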
901 
902 DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
903   if (t_attr_id.attr_name.empty()) {
904     return ToSet(t_attr_id.fixed_type);
905   }
906   const OpDef::AttrDef* attr_def = FindAttr(t_attr_id.attr_name, op_def);
907   CHECK(attr_def);  // Crash Ok
908   return AllowedDataTypes(*attr_def);
909 }
910 
911 Status ValidateLists(const gtl::FlatSet<string>& allow_list,
912                      const gtl::FlatSet<string>& deny_list,
913                      const gtl::FlatSet<string>& infer_list,
914                      const gtl::FlatSet<string>& clear_list) {
915   std::vector<gtl::FlatSet<string>> lists{allow_list, deny_list, infer_list,
916                                           clear_list};
917   std::multiset<string> counts;
918   for (const auto& list : lists) {
919     counts.insert(list.begin(), list.end());
920   }
921   bool duplicates = false;
922   for (const auto& s : counts) {
923     if (counts.count(s) > 1) {
924       duplicates = true;
925       LOG(ERROR) << "Op present in multiple lists: " << s;
926     }
927   }
928   if (duplicates) {
929     return errors::InvalidArgument("Op lists have conflicting entries");
930   } else {
931     return OkStatus();
932   }
933 }
934 
935 bool HasInputOrOutputRefs(const NodeDef& node) {
936   const OpDef* op_def;
937   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
938   if (!status.ok()) {
939     return true;
940   }
941   for (const auto& input : op_def->input_arg()) {
942     if (input.is_ref()) {
943       return true;
944     }
945   }
946   for (const auto& output : op_def->output_arg()) {
947     if (output.is_ref()) {
948       return true;
949     }
950   }
951   return false;
952 }
953 
954 // See TF issue 25977 for why FP16 is not forced on SoftmaxCrossEntropyWithLogits.
955 bool CanForceFP16(const NodeDef& node) {
956   return node.op() != "Const" && node.op() != "SoftmaxCrossEntropyWithLogits" &&
957          !IsStateful(node) && !HasInputOrOutputRefs(node);
958 }
959 
960 int GetCudaVersion(
961     const std::unordered_map<string, DeviceProperties>& devices) {
962   for (const auto& device : devices) {
963     const DeviceProperties& device_properties = device.second;
964     if (device_properties.type() == "GPU") {
965       const auto& device_env = device_properties.environment();
966       auto it = device_env.find("cuda");
967       if (it != device_env.end()) {
968         string cuda_version_str = it->second;
969         return std::stoi(cuda_version_str);
970       }
971     }
972   }
973   return 0;
974 }
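// Example (illustrative): a GPU whose environment() map contains
// {"cuda": "11050"} makes GetCudaVersion() return 11050 (i.e. CUDA 11.5);
// 0 means no GPU with a reported CUDA version was found.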
975 
976 int GetCudnnVersion(
977     const std::unordered_map<string, DeviceProperties>& devices) {
978   for (const auto& device : devices) {
979     const DeviceProperties& device_properties = device.second;
980     if (device_properties.type() == "GPU") {
981       const auto& device_env = device_properties.environment();
982       auto it = device_env.find("cudnn");
983       if (it != device_env.end()) {
984         string cudnn_version_str = it->second;
985         return std::stoi(cudnn_version_str);
986       }
987     }
988   }
989   return 0;
990 }
991 
992 std::unordered_map<string, DeviceProperties> GetDevices(Cluster* cluster) {
993   if (!ShouldSimulateGpu()) {
994     return cluster->GetDevices();
995   }
996 
997   bool has_gpu = false;
998   for (const auto& device : cluster->GetDevices()) {
999     const DeviceProperties& device_properties = device.second;
1000     if (device_properties.type() == "GPU") {
1001       has_gpu = true;
1002       break;
1003     }
1004   }
1005 
1006   if (has_gpu) {
1007     return cluster->GetDevices();
1008   }
1009 
1010   std::unordered_map<string, DeviceProperties> devices(cluster->GetDevices());
1011   DeviceProperties gpu_device_properties;
1012   gpu_device_properties.set_type("GPU");
1013 #if GOOGLE_CUDA
1014   gpu_device_properties.set_vendor("NVIDIA");
1015   gpu_device_properties.mutable_environment()->insert({"architecture", "8.0"});
1016   gpu_device_properties.mutable_environment()->insert({"cuda", "11050"});
1017   gpu_device_properties.mutable_environment()->insert({"cudnn", "8302"});
1018 #elif TENSORFLOW_USE_ROCM
1019   gpu_device_properties.set_vendor("Advanced Micro Devices, Inc");
1020   gpu_device_properties.mutable_environment()->insert(
1021       {"architecture", "gfx908"});
1022 #endif
1023   devices.emplace(std::make_pair("/job:localhost/replica:0/task:0/device:GPU:0",
1024                                  gpu_device_properties));
1025   return devices;
1026 }
1027 
1028 class AutoMixedPrecisionImpl {
1029  public:
1030   // CastType indicates the type of inserted Cast op
1031   //   FP16: cast to float16
1032   //   FP32: cast to float32
1033   //   AUTO: cast to a data type that matches the required data type at fanouts
1034   enum class CastType { FP16, FP32, AUTO };
1035   AutoMixedPrecisionImpl(Cluster* cluster,
1036                          const std::unordered_set<string>& nodes_to_preserve,
1037                          GraphDef* graph, string id,
1038                          AutoMixedPrecisionMode mode)
1039       : devices_(GetDevices(cluster)),
1040         virtual_placer_(devices_),
1041         nodes_to_preserve_(nodes_to_preserve),
1042         graph_(graph),
1043         function_library_(OpRegistry::Global(), graph->library()),
1044         id_(id),
1045         graph_view_(graph),
1046         cuda_version_(GetCudaVersion(devices_)),
1047         cudnn_version_(GetCudnnVersion(devices_)),
1048         num_nonvar_casts_to_f16_(0),
1049         mode_(mode),
1050         target_dtype_((mode_ == AutoMixedPrecisionMode::CUDA ||
1051                        mode_ == AutoMixedPrecisionMode::CPU)
1052                           ? DT_HALF
1053                           : DT_BFLOAT16) {}
1054 
1055   Status Optimize();
1056 
1057  private:
1058   typedef absl::flat_hash_set<NodeTypeId> NodeTypeIdSet;
1059 
1060   std::unique_ptr<AutoMixedPrecisionLists> get_mixed_precision_lists() const {
1061     switch (mode_) {
1062       case AutoMixedPrecisionMode::CUDA:
1063         return std::make_unique<AutoMixedPrecisionListsCuda>(cuda_version_,
1064                                                              cudnn_version_);
1065       case AutoMixedPrecisionMode::BF16:
1066         return std::make_unique<AutoMixedPrecisionListsMkl>();
1067       case AutoMixedPrecisionMode::CPU:
1068         // Note: this is not a typo here. AutoMixedPrecisionListsCuda is used
1069         // intentionally to make CPU and GPU have the same fp16 ops.
1070         return std::make_unique<AutoMixedPrecisionListsCuda>(
1071             /*cuda_version=*/10000,   // Hardcode cuda and cudnn version so
1072             /*cudnn_version=*/8000);  // CPU emulates the same ops on GPU.
1073     }
1074   }
1075   Status PrintDebugLogs(bool preop, size_t timestamp);
1076   void LogSkippedNode(const NodeDef& node, const string& device_type) const;
1077   bool MustPreserve(const NodeDef& node) const;
1078   bool IsOnDevice(const NodeDef& node, const string& device_type) const;
1079   bool IsOnSuitableGPUArch(const NodeDef& node) const;
1080   bool ShouldProcess(const NodeDef& node) const;
1081   bool NodeHasF16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
1082   bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const;
1083   void ConvertBatchNormOpsToV2();
1084   bool SupportsF16(const NodeTypeId& node_type) const;
1085   bool SupportsF16DataType(const NodeTypeId& node_type) const;
1086   bool IsQuantized(const NodeTypeId& node_type) const;
1087   const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const;
1088   bool IsSourceOrSinkOp(const string& op) const;
1089   void FindFloat32TensorListOpClustersAndDenylistUnsafe(
1090       std::vector<absl::flat_hash_set<const NodeDef*>>* clusters,
1091       absl::flat_hash_set<int>* deny_set) const;
1092   void FindTensorListImplicitFloat32Edges(
1093       const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1094       std::vector<NodeTypeIdEdge>* implicit_fp32_edges) const;
1095   void AddAllowlistOps(absl::flat_hash_set<int>* allow_set) const;
1096   void RemoveAllowsetWithFp32(absl::flat_hash_set<int>* allow_set) const;
1097   void PropagateDenyFwdThroughClearAndInfer(
1098       absl::flat_hash_set<int>* deny_set) const;
1099   void ForceColorMatchBetweenTensorListOps(
1100       const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1101       absl::flat_hash_set<int>* allow_set,
1102       absl::flat_hash_set<int>* deny_set) const;
1103   void AddClearAndInferToAllowIfBetweenAllow(
1104       const absl::flat_hash_set<int>& deny_set,
1105       absl::flat_hash_set<int>* allow_set) const;
1106   void AddInferToAllowIfFollowAllow(const absl::flat_hash_set<int>& deny_set,
1107                                     absl::flat_hash_set<int>* allow_set) const;
1108   void PropagateAllowThroughClear(const absl::flat_hash_set<int>& deny_set,
1109                                   absl::flat_hash_set<int>* allow_set) const;
1110   Status ForceColorMatchOnRecurrentEdges(
1111       absl::flat_hash_set<int>* allow_set) const;
1112   void MakeCastsAllowIfAllOutputsAllow(
1113       absl::flat_hash_set<int>* allow_set) const;
1114   NodeDef BuildCastNode(const MutableGraphView::OutputPort& src,
1115                         const MutableGraphView::InputPort& dst, bool to_f16,
1116                         const string& device) const;
1117   StatusOr<NodeDef*> InsertCastNodeAtFanout(
1118       const absl::flat_hash_set<int>& allow_set, const bool src_is_allow,
1119       const CastType& cast_type, MutableGraphView::OutputPort& src);
1120 
1121   StatusOr<DataType> GetCastToType(const NodeDef* node) const;
1122   void CollectOutputPorts(
1123       const TypeAttrId& type_attr, NodeDef* node,
1124       std::vector<MutableGraphView::OutputPort>& output_ports) const;
1125   Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& allow_set);
1126 
1127   std::unordered_map<string, DeviceProperties> devices_;
1128   VirtualPlacer virtual_placer_;
1129   std::unordered_set<string> nodes_to_preserve_;
1130   GraphDef* graph_;
1131   FunctionLibraryDefinition function_library_;
1132   string id_;
1133   MutableGraphView graph_view_;
1134   int cuda_version_;
1135   int cudnn_version_;
1136   int num_nonvar_casts_to_f16_;
1137   NodeTypeAttrMap node_type_map_;
1138   GraphTypeTopologyView graph_type_view_;
1139   bool force_all_fp16_;
1140   bool treat_infer_as_deny_;
1141   AutoMixedPrecisionMode mode_;
1142   gtl::FlatSet<string> f16_allowlist_;
1143   gtl::FlatSet<string> f16_denylist_;
1144   gtl::FlatSet<string> f16_inferlist_;
1145   gtl::FlatSet<string> f16_clearlist_;
1146   absl::flat_hash_set<const NodeDef*> should_process_nodes_;
1147   DataType target_dtype_;  // Either DT_HALF or DT_BFLOAT16
1148 };
1149 
1150 NodeDef AutoMixedPrecisionImpl::BuildCastNode(
1151     const MutableGraphView::OutputPort& src,
1152     const MutableGraphView::InputPort& dst, bool to_f16,
1153     const string& device) const {
1154   DataType src_type = to_f16 ? DT_FLOAT : target_dtype_;
1155   DataType dst_type = to_f16 ? target_dtype_ : DT_FLOAT;
1156   const char* cast_string = !to_f16                    ? kCastToFp32
1157                             : target_dtype_ == DT_HALF ? kCastToFp16
1158                                                        : kCastToBf16;
1159   string name =
1160       strings::StrCat(src.node->name(), "-", src.port_id, "-", dst.node->name(),
1161                       "-", dst.port_id, "-", cast_string, "-", kSuffix);
1162   NodeDef node;
1163   node.set_name(name);
1164   node.set_op("Cast");
1165   node.set_device(device);
1166   node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
1167   (*node.mutable_attr())["SrcT"].set_type(src_type);
1168   (*node.mutable_attr())["DstT"].set_type(dst_type);
1169   (*node.mutable_attr())["Truncate"].set_b(false);
1170   return node;
1171 }
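// Example (illustrative, assuming target_dtype_ is DT_HALF): casting output 0
// of node "conv1" into input 0 of node "relu1" with to_f16=true produces a Cast
// node named "conv1-0-relu1-0-CastToFp16-AutoMixedPrecision" with SrcT=DT_FLOAT,
// DstT=DT_HALF and input "conv1:0".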
1172 
1173 bool AutoMixedPrecisionImpl::NodeHasF16KernelForTypeAttr(
1174     const NodeDef& node, TypeAttrId taid) const {
1175   NodeDef node_copy(node);
1176   if (node.device().empty()) {
1177     string device_name = virtual_placer_.get_canonical_device_name(node);
1178     node_copy.set_device(device_name);
1179   }
1180   if (!SetDataType(&node_copy, taid, target_dtype_)) {
1181     return false;
1182   }
1183   return IsKernelRegisteredForNode(node_copy).ok();
1184 }
1185 
1186 Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
1187   string prepend_path;
1188   TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
1189       "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LOG_PATH", "", &prepend_path));
1190   if (prepend_path.empty()) return OkStatus();
1191 
1192   string suffix =
1193       strings::StrCat("_", preop ? "preop" : kSuffix, "_", id_, "_", timestamp);
1194 
1195   string fname =
1196       io::JoinPath(prepend_path, strings::StrCat("graphdef", suffix, ".pb"));
1197   std::fstream f;
1198   f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
1199   f << graph_->SerializeAsString();
1200   f.close();
1201   LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
1202             << " graph as binary to " << fname;
1203 
1204   fname = io::JoinPath(prepend_path,
1205                        strings::StrCat("graphdef", suffix, ".pb.txt"));
1206   f.open(fname.c_str(), std::fstream::out);
1207   f << graph_->DebugString();
1208   f.close();
1209   LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
1210             << " graph as text to " << fname;
1211 
1212   if (!preop) {
1213     fname = io::JoinPath(prepend_path,
1214                          strings::StrCat("paintbuckets", suffix, ".txt"));
1215     f.open(fname.c_str(), std::fstream::out);
1216     std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
1217         get_mixed_precision_lists();
1218     f << "AllowList:\n";
1219     for (const auto& x : mp_lists->AllowList()) {
1220       f << x << "\n";
1221     }
1222     f << "\nDenyList:\n";
1223     for (const auto& x : mp_lists->DenyList()) {
1224       f << x << "\n";
1225     }
1226     f << "\nInferList:\n";
1227     for (const auto& x : mp_lists->InferList()) {
1228       f << x << "\n";
1229     }
1230     f << "\nClearList:\n";
1231     for (const auto& x : mp_lists->ClearList()) {
1232       f << x << "\n";
1233     }
1234     f.close();
1235     LOG(INFO) << "Saved paint bucket info to " << fname;
1236   }
1237   return OkStatus();
1238 }
1239 
1240 void AutoMixedPrecisionImpl::LogSkippedNode(const NodeDef& node,
1241                                             const string& device_type) const {
1242   VLOG(2) << "Skipping " << node.op() << " node " << node.name()
1243           << " because it "
1244           << (MustPreserve(node)
1245                   ? "must be preserved"
1246                   : absl::StrFormat(
1247                         "is not on the %s, or the %s arch is not suitable",
1248                         device_type, device_type));
1249 }
1250 
1251 bool AutoMixedPrecisionImpl::MustPreserve(const NodeDef& node) const {
1252   return nodes_to_preserve_.count(node.name());
1253 }
1254 
1255 bool AutoMixedPrecisionImpl::IsOnDevice(const NodeDef& node,
1256                                         const string& device_type) const {
1257   string device_name;
1258   if (node.device().empty()) {
1259     device_name = virtual_placer_.get_canonical_device_name(node);
1260   } else {
1261     device_name = node.device();
1262   }
1263   string device;
1264   string not_used;
1265   if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
1266       absl::StrContains(absl::AsciiStrToLower(device),
1267                         absl::AsciiStrToLower(device_type))) {
1268     return true;
1269   }
1270   return false;
1271 }
1272 
1273 bool AutoMixedPrecisionImpl::IsOnSuitableGPUArch(const NodeDef& node) const {
1274   return HasFastFP16Support(virtual_placer_.get_device(node));
1275 }
1276 
1277 bool AutoMixedPrecisionImpl::ShouldProcess(const NodeDef& node) const {
1278   return should_process_nodes_.count(&node);
1279 }
1280 
1281 bool IsFloat32(const NodeTypeId& node_type) {
1282   return GetDataType(*node_type.node, node_type.type_attr) ==
1283          DataType::DT_FLOAT;
1284 }
1285 
1286 bool IsTensorListOp(const string& op) {
1287   return absl::StrContains(op, "TensorList");
1288 }
1289 
1290 bool IsTensorListReaderOp(const string& op) {
1291   static const gtl::FlatSet<string> tensor_list_reader_ops = {
1292       "TensorListConcat",  "TensorListConcatV2", "TensorListGather",
1293       "TensorListGetItem", "TensorListPopBack",  "TensorListStack"};
1294   return tensor_list_reader_ops.count(op);
1295 }
1296 
1297 bool IsTensorListWriterOp(const string& op) {
1298   static const gtl::FlatSet<string> tensor_list_writer_ops = {
1299       "TensorListFromTensor",    "TensorListPushBack",
1300       "TensorListPushBackBatch", "TensorListScatter",
1301       "TensorListScatterV2",     "TensorListScatterIntoExistingList",
1302       "TensorListSetItem",       "TensorListSplit"};
1303   return tensor_list_writer_ops.count(op);
1304 }
1305 
1306 bool AutoMixedPrecisionImpl::SupportsF16(const NodeTypeId& node_type) const {
1307   const OpDef* op_def;
1308   Status status =
1309       OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
1310   if (!status.ok()) return false;
1311   return AllowedDataTypes(*op_def, node_type.type_attr)
1312              .Contains(target_dtype_) &&
1313          NodeHasF16KernelForTypeAttr(*node_type.node, node_type.type_attr);
1314 }
1315 
1316 bool AutoMixedPrecisionImpl::SupportsF16DataType(
1317     const NodeTypeId& node_type) const {
1318   const OpDef* op_def;
1319   Status status =
1320       OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
1321   if (!status.ok()) return false;
1322   return AllowedDataTypes(*op_def, node_type.type_attr).Contains(target_dtype_);
1323 }
1324 
1325 bool AutoMixedPrecisionImpl::IsQuantized(const NodeTypeId& node_type) const {
1326   for (const TypeAttrId& type_attr :
1327        node_type_map_.GetTypeAttrs(*node_type.node)) {
1328     if (DataTypeIsQuantized(GetDataType(*node_type.node, type_attr))) {
1329       return true;
1330     }
1331   }
1332   return false;
1333 }
1334 
1335 // TODO(mconley): Make this change the node's name (to aid debugging). Need to
1336 // make sure that doing this won't break anything.
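// FusedBatchNormV2/V3 expose a separate type attribute 'U' for the batch
// statistics; it only supports float (see RemoveAllowsetWithFp32 below), so it
// is pinned to DT_FLOAT here.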
1337 void AutoMixedPrecisionImpl::ConvertBatchNormOpsToV2() {
1338   for (int node_idx = 0; node_idx < graph_->node_size(); ++node_idx) {
1339     NodeDef* node = graph_->mutable_node(node_idx);
1340     if (!ShouldProcess(*node)) continue;
1341     bool changed = false;
1342     if (node->op() == "FusedBatchNorm") {
1343       VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
1344               << " to FusedBatchNormV2";
1345       node->set_op("FusedBatchNormV2");
1346       changed = true;
1347     } else if (node->op() == "FusedBatchNormGrad") {
1348       VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
1349               << " to FusedBatchNormGradV2";
1350       node->set_op("FusedBatchNormGradV2");
1351       changed = true;
1352     }
1353     if (changed) {
1354       (*node->mutable_attr())["U"].set_type(DT_FLOAT);
1355     }
1356   }
1357 }
1358 
1359 // A helper function to decide whether to ignore the effect on performance when
1360 // rewriting the graph. This can be useful for testing the numerical effects of
1361 // reduced precision on systems that have poor mixed precision performance.
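// For example (shell usage, illustration only):
//   export TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE=1
// makes the pass treat every GPU as suitable regardless of its architecture.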
1362 bool ShouldIgnorePerformance() {
1363   static bool is_enabled = [] {
1364     bool ret = false;
1365     TF_CHECK_OK(ReadBoolFromEnvVar(
1366         "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE",
1367         /*default_val=*/false, &ret));
1368     return ret;
1369   }();
1370   return is_enabled;
1371 }
1372 
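// The env var TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL is read below; the
// values recognized by this pass (case-insensitively) are "UNSAFE_FORCE_ALL"
// and "TREAT_INFER_AS_DENY". Any other value leaves the default behavior.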
1373 Status AutoMixedPrecisionImpl::Optimize() {
1374   string optimization_level;
1375   TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
1376       "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level));
1377   optimization_level = absl::AsciiStrToUpper(optimization_level);
1378   force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL";
1379   if (force_all_fp16_ && mode_ == AutoMixedPrecisionMode::BF16) {
1380     // Many ops do not support bfloat16 on the CPU, so we disallow forcing
1381     // everything to bfloat16.
1382     return errors::InvalidArgument(
1383         "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL cannot be set to "
1384         "UNSAFE_FORCE_ALL when oneDNN is used");
1385   }
1386 
1387   treat_infer_as_deny_ = optimization_level == "TREAT_INFER_AS_DENY";
1388   VLOG(2) << "Optimization Level: " << optimization_level;
1389 
1390   std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
1391       get_mixed_precision_lists();
1392   f16_allowlist_ = mp_lists->AllowList();
1393   f16_denylist_ = mp_lists->DenyList();
1394 
1395   if (treat_infer_as_deny_) {
1396     for (const auto& op : mp_lists->InferList()) {
1397       f16_denylist_.insert(op);
1398     }
1399   } else {
1400     f16_inferlist_ = mp_lists->InferList();
1401   }
1402 
1403   f16_clearlist_ = mp_lists->ClearList();
1404   TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_denylist_,
1405                                    f16_inferlist_, f16_clearlist_));
1406 
1407   size_t timestamp = Env::Default()->NowMicros() / 1000;
1408   TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));
1409 
1410   VLOG(2) << "Identifying nodes that should be processed";
1411   for (const NodeDef& node : graph_->node()) {
1412     bool should_process;
1413     string device_type;
1414     switch (mode_) {
1415       case AutoMixedPrecisionMode::CUDA:
1416         device_type = DEVICE_GPU;
1417         should_process =
1418             !MustPreserve(node) && IsOnDevice(node, device_type) &&
1419             (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node));
1420         break;
1421       case AutoMixedPrecisionMode::BF16:
1422       case AutoMixedPrecisionMode::CPU:
1423         device_type = DEVICE_CPU;
1424         should_process = !MustPreserve(node) && IsOnDevice(node, device_type);
1425         break;
1426     }
1427     if (should_process) {
1428       should_process_nodes_.insert(&node);
1429     } else {
1430       LogSkippedNode(node, device_type);
1431     }
1432   }
1433 
1434   VLOG(2) << "Converting FusedBatchNorm* ops to V2";
1435   ConvertBatchNormOpsToV2();
1436 
1437   VLOG(2) << "Building node type map for graph";
1438   TF_RETURN_IF_ERROR(node_type_map_.Init(*graph_));
1439 
1440   VLOG(2) << "Constructing graph type attribute topology view";
1441   TF_RETURN_IF_ERROR(
1442       graph_type_view_.InitializeFromGraph(*graph_, node_type_map_));
1443 
1444   absl::flat_hash_set<int> deny_set;
1445 
1446   std::vector<absl::flat_hash_set<const NodeDef*>> tensor_list_clusters;
1447   FindFloat32TensorListOpClustersAndDenylistUnsafe(&tensor_list_clusters,
1448                                                    &deny_set);
1449   std::vector<NodeTypeIdEdge> ephemeral_edges;
1450   for (const auto& cluster : tensor_list_clusters) {
1451     VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size();
1452     for (const NodeDef* node : cluster) {
1453       VLOG(2) << "  Cluster member: " << node->op() << " node " << node->name();
1454     }
1455     FindTensorListImplicitFloat32Edges(cluster, &ephemeral_edges);
1456   }
1457   TF_RETURN_IF_ERROR(graph_type_view_.AddEphemeralEdges(ephemeral_edges));
1458 
1459   // The goal here is to change performance-critical ops to fp16 or bf16, and to
1460   // do so with the minimal number of casts, subject to the constraint that the
1461   // model's convergence is not affected. This is achieved by first identifying
1462   // which nodes should be changed to f16 and then inserting casts at the
1463   // boundaries between f16/non-f16 nodes.
1464 
1465   // The algorithm for deciding which nodes to change to f16 is as follows:
1466   // 1) Add all performance-critical ops (aka "allowlist" ops) to the allow_set.
1467   //    This is done under the assumption that allowlist ops are always
1468   //    numerically-safe in f16 and that they are the most important ops for
1469   //    improving performance.
1470   // 2) Add nodes to the deny_set iff they are numerically-dangerous (aka
1471   //    "denylist" ops) or they are on a forward path from a denylist node to
1472   //    a deny/infer node (including the node at the end of the path) through
1473   //    non-numerically-dangerous ops (aka "inferlist" and "clearlist" ops).
1474   //    This is done to prevent numerically-dangerous ops and their downstream
1475   //    effects from being changed to f16, which would risk breaking the
1476   //    numerical accuracy of the model.
1477   // 3) For all remaining nodes that are not considered dangerous (inferlist
1478   //    and clearlist ops), find those that are between (i.e., both upstream
1479   //    and downstream of) allow nodes, and add them to the allow_set.
1480   //    This is done to avoid unnecessary casts between allowlist ops.
1481   // 4) For the remaining inferlist nodes, add them to the allow_set if they
1482   //    are immediately downstream of a node in the allow_set.
1483   // 5) For all remaining clearlist nodes, add them to the allow_set if they are
1484   //    connected to a node in the allow_set via other clearlist nodes.
1485   //    This is done to increase the number of ops in the allow_set without
1486   //    affecting numerical stability.
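  //
  // A tiny worked example (hypothetical graph; this assumes MatMul is on the
  // allowlist, Relu on the clearlist and Exp on the denylist -- the real lists
  // are defined in auto_mixed_precision_lists.h):
  //   MatMul -> Relu -> MatMul : pass 1 paints both MatMuls ALLOW and pass 3
  //     paints the Relu ALLOW because it sits between two allow nodes, so the
  //     whole chain runs in f16 and casts are only inserted at its boundaries.
  //   MatMul -> Exp -> Relu    : pass 1 paints the MatMul ALLOW and pass 2
  //     paints the Exp DENY, so a cast back to fp32 is inserted on the
  //     MatMul -> Exp edge and the Exp and Relu stay in fp32.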
1487 
1488   absl::flat_hash_set<int> allow_set;
1489   VLOG(2) << "Beginning pass 1 to add allowlist ops";
1490   AddAllowlistOps(&allow_set);
1491   VLOG(2) << "Finished pass 1";
1492 
1493   if (allow_set.empty()) {
1494     LOG(INFO) << "No allowlist ops found, nothing to do";
1495     return OkStatus();
1496   }
1497 
1498   VLOG(2) << "Beginning pass 2 to propagate deny forwards from denylist ops "
1499              "through clear/inferlist ops";
1500   PropagateDenyFwdThroughClearAndInfer(&deny_set);
1501   VLOG(2) << "Finished pass 2";
1502 
1503   VLOG(2) << "Forcing color match between data structure ops";
1504   for (const auto& cluster : tensor_list_clusters) {
1505     ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
1506   }
1507 
1508   VLOG(2) << "Beginning pass 3 to set clear and infer nodes to allow if they "
1509              "are between allow ops";
1510   AddClearAndInferToAllowIfBetweenAllow(deny_set, &allow_set);
1511   VLOG(2) << "Finished pass 3";
1512 
1513   VLOG(2) << "Beginning pass 4 to add infer list ops to allow if they "
1514              "directly follow allow nodes";
1515   AddInferToAllowIfFollowAllow(deny_set, &allow_set);
1516   VLOG(2) << "Finished pass 4";
1517 
1518   VLOG(2) << "Beginning pass 5 to propagate allow from allow nodes through "
1519              "clearlist ops";
1520   PropagateAllowThroughClear(deny_set, &allow_set);
1521   VLOG(2) << "Finished pass 5";
1522 
1523   VLOG(2) << "Beginning pass 6 to remove nodes that cannot be changed "
1524              "to F16 "
1525              "from the allow set";
1526   RemoveAllowsetWithFp32(&allow_set);
1527   VLOG(2) << "Finished pass 6";
1528 
1529   VLOG(2) << "Forcing color match between data structure ops";
1530   for (const auto& cluster : tensor_list_clusters) {
1531     ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
1532   }
1533 
1534   VLOG(2) << "Forcing color match on loop edges";
1535   TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&allow_set));
1536 
1537   VLOG(2) << "Finding existing casts that can be made allow";
1538   MakeCastsAllowIfAllOutputsAllow(&allow_set);
1539 
1540   VLOG(2) << "Beginning final pass to change type attributes and insert Cast "
1541              "ops at paint boundaries";
1542   TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(allow_set));
1543   VLOG(2) << "Finished final pass";
1544 
1545   TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp));
1546 
1547   return OkStatus();
1548 }
1549 
1550 // If node is a Tensor List op with a float32 data type attribute then this
1551 // returns a pointer to the NodeTypeId representing that type attribute. In
1552 // all other cases this returns nullptr.
1553 const NodeTypeId* AutoMixedPrecisionImpl::GetTensorListFloat32NodeTypeId(
1554     const NodeDef& node) const {
1555   if (!IsTensorListOp(node.op())) return nullptr;
1556   for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(node)) {
1557     const NodeTypeId* node_type =
1558         graph_type_view_.GetNode(node.name(), type_attr);
1559     // This assumes that the float32 data type on a Tensor List op is always a
1560     // non-fixed type attribute containing a single type, and that this type
1561     // attribute represents the dtype of the values in the list.
1562     // TODO(benbarsdell): A new Tensor List op could theoretically break these
1563     // assumptions.
1564     if (node_type && node_type->type_attr.fixed_type == DT_INVALID &&
1565         node_type->type_attr.type_index == TypeAttrId::kSingleType &&
1566         IsFloat32(*node_type)) {
1567       return node_type;
1568     }
1569   }
1570   return nullptr;
1571 }
1572 
1573 bool AutoMixedPrecisionImpl::IsSourceOrSinkOp(const string& op) const {
1574   const gtl::FlatSet<string> source_and_sink_ops = {
1575       "_Arg",
1576       "_Retval",
1577       "OptionalFromValue",
1578       "OptionalGetValue",
1579       "PartitionedCall",
1580       "Placeholder",
1581       "StatefulPartitionedCall",
1582   };
1583   return source_and_sink_ops.count(op) || function_library_.Find(op);
1584 }
1585 
1586 // Finds all clusters of float32 Tensor List nodes that are connected via their
1587 // handle edges. Unsafe clusters (those with unprocessable nodes, or with edges
1588 // that cross untraversable boundaries via _Arg, _Ret, PartitionedCall etc.
1589 // nodes) are added to deny_set. The caller should paint all nodes in a cluster
1590 // the same color, as they may all refer to the same Tensor List.
1591 void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndDenylistUnsafe(
1592     std::vector<absl::flat_hash_set<const NodeDef*>>* tensor_list_clusters,
1593     absl::flat_hash_set<int>* deny_set) const {
1594   absl::flat_hash_set<const NodeDef*> tensor_list_prop_set;
1595   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1596     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1597     if (!ShouldProcess(*root.node) ||
1598         root.type_attr.fixed_type != DataType::DT_VARIANT ||
1599         !GetTensorListFloat32NodeTypeId(*root.node) ||
1600         tensor_list_prop_set.count(root.node)) {
1601       continue;
1602     }
1603     const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node);
1604     const absl::optional<int> maybe_root_fp32_idx =
1605         graph_type_view_.GetNodeIndex(*root_fp32);
1606     DCHECK(maybe_root_fp32_idx.has_value())
1607         << "Type attribute " << root_fp32->type_attr.DebugString()
1608         << " of node " << root.node->name() << " not found in graph view";
1609     int root_fp32_idx = maybe_root_fp32_idx.value();
1610     // Traverse Tensor List handle edges (DT_VARIANT) to find cluster of all
1611     // connected Tensor List nodes.
1612     absl::flat_hash_set<const NodeDef*> cluster({root.node});
1613     DfsTypeTraversal(graph_type_view_, {&root},
1614                      TypeTraversalDirection::kFollowInputsAndOutputs,
1615                      DfsTypePredicates::Enter([&](int idx) -> bool {
1616                        const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1617                        return !tensor_list_prop_set.count(item.node);
1618                      }),
1619                      DfsTypeCallbacks::PreOrder([&](int idx) {
1620                        const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1621                        const NodeDef* node = item.node;
1622                        if (GetTensorListFloat32NodeTypeId(*node)) {
1623                          cluster.insert(node);
1624                          if (!ShouldProcess(*node)) {
1625                            // The cluster contains an un-processable node.
1626                            deny_set->insert(root_fp32_idx);
1627                          }
1628                          // TODO(benbarsdell): In a theoretical pathological
1629                          // case of a Tensor List of Tensor List handles, the
1630                          // Tensor List itself would need to be treated as a
1631                          // sink.
1632                        } else if (IsSourceOrSinkOp(node->op())) {
1633                          // The cluster crosses an untraversable boundary.
1634                          deny_set->insert(root_fp32_idx);
1635                        }
1636                      }));
1637     tensor_list_clusters->push_back(cluster);
1638   }
1639 }
1640 
1641 // Finds all writer -> reader pairs in the given set that are connected via
1642 // their handles, and adds corresponding float32 edges to *implicit_fp32_edges.
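// For example (hypothetical node names): if a TensorListPushBack node "w" and
// a TensorListPopBack node "r" operate on the same list handle, an ephemeral
// fp32 edge w -> r is added so that the painting passes see the values written
// by "w" flowing into "r", even though the real edge only carries a DT_VARIANT
// handle.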
1643 void AutoMixedPrecisionImpl::FindTensorListImplicitFloat32Edges(
1644     const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1645     std::vector<NodeTypeIdEdge>* implicit_fp32_edges) const {
1646   for (const NodeDef* root_node : tensor_list_nodes) {
1647     if (!IsTensorListReaderOp(root_node->op())) continue;
1648     NodeTypeId root(root_node, TypeAttrId(DataType::DT_VARIANT));
1649     const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node);
1650     CHECK(root_fp32) << "No float32 type attribute found for "  // Crash OK
1651                      << root.node->op() << " node " << root.node->name();
1652     // Search backwards through handle edges (DT_VARIANT) for all writer ops,
1653     // adding direct implicit edges between them and the reader.
1654     DfsTypeTraversal(
1655         graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
1656         DfsTypePredicates::Enter([&](int idx) -> bool {
1657           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1658           return ShouldProcess(*item.node);
1659         }),
1660         DfsTypeCallbacks::PreOrder([&](int idx) {
1661           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1662           if (IsTensorListWriterOp(item.node->op())) {
1663             const NodeTypeId* item_fp32 =
1664                 GetTensorListFloat32NodeTypeId(*item.node);
1665             CHECK(item_fp32)  // Crash OK
1666                 << "No float32 type attribute found for " << item.node->op()
1667                 << " node " << item.node->name();
1668             VLOG(2) << "Adding ephemeral float32 edge from "
1669                     << item_fp32->node->op() << " node "
1670                     << item_fp32->node->name() << " to "
1671                     << root_fp32->node->op() << " node "
1672                     << root_fp32->node->name();
1673             implicit_fp32_edges->emplace_back(*item_fp32, *root_fp32);
1674           }
1675         }));
1676   }
1677 }
1678 
1679 void AutoMixedPrecisionImpl::AddAllowlistOps(
1680     absl::flat_hash_set<int>* allow_set) const {
1681   // Add allowlisted ops to allow_set.
1682   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1683     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1684     if (!ShouldProcess(*root.node)) continue;
1685     bool force_allow = force_all_fp16_ && CanForceFP16(*root.node);
1686     if (f16_allowlist_.count(root.node->op()) || force_allow) {
1687       bool inserted = allow_set->insert(root_idx).second;
1688       if (VLOG_IS_ON(2) && inserted) {
1689         VLOG(2) << "Painting type " << root.type_attr.DebugString()
1690                 << " of node " << root.node->name() << " ALLOW because its op "
1691                 << root.node->op() << " is on the allowlist";
1692       }
1693     }
1694   }
1695 }
1696 
1697 // Adds nodes to deny_set iff they are on the denylist or they are on a
1698 // forward path from a denylist node to a deny/infer node (including the node
1699 // at the end of the path) through clear and infer nodes.
1700 // E.g., deny -> infer -> clear -> infer -> clear -> allow -> infer
1701 // becomes: deny -> deny -> deny -> deny -> clear -> allow -> infer.
1702 void AutoMixedPrecisionImpl::PropagateDenyFwdThroughClearAndInfer(
1703     absl::flat_hash_set<int>* deny_set) const {
1704   if (force_all_fp16_) return;
1705 
1706   // Find clear nodes that are upstream of deny or infer.
1707   absl::flat_hash_set<int> upstream_of_deny_or_infer_set;
1708   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1709     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1710     if (!(f16_denylist_.count(root.node->op()) ||
1711           f16_inferlist_.count(root.node->op()))) {
1712       continue;
1713     }
1714     DfsTypeTraversal(graph_type_view_, {&root},
1715                      TypeTraversalDirection::kFollowInputs,
1716                      DfsTypePredicates::Enter([&](int idx) -> bool {
1717                        const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1718                        return idx == root_idx ||
1719                               (!upstream_of_deny_or_infer_set.count(idx) &&
1720                                f16_clearlist_.count(item.node->op()));
1721                      }),
1722                      DfsTypeCallbacks::PreOrder([&](int idx) {
1723                        upstream_of_deny_or_infer_set.insert(idx);
1724                      }));
1725   }
1726 
1727   // Propagate deny forward through nodes in upstream_of_deny_or_infer_set.
1728   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1729     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1730     if (deny_set->count(root_idx) || !f16_denylist_.count(root.node->op())) {
1731       continue;
1732     }
1733     DfsTypeTraversal(
1734         graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
1735         DfsTypePredicates::Enter([&](int idx) -> bool {
1736           return idx == root_idx || (!deny_set->count(idx) &&
1737                                      upstream_of_deny_or_infer_set.count(idx));
1738         }),
1739         DfsTypeCallbacks::PreOrder([&](int idx) {
1740           bool inserted = deny_set->insert(idx).second;
1741           if (VLOG_IS_ON(2) && inserted) {
1742             const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1743             VLOG(2) << "Painting type " << item.type_attr.DebugString()
1744                     << " of " << item.node->op() << " node "
1745                     << item.node->name() << " DENY";
1746           }
1747         }));
1748   }
1749 }
1750 
1751 void AutoMixedPrecisionImpl::AddClearAndInferToAllowIfBetweenAllow(
1752     const absl::flat_hash_set<int>& deny_set,
1753     absl::flat_hash_set<int>* allow_set) const {
1754   // Find clear/inferlist ops that are downstream of allow ops.
1755   absl::flat_hash_set<int> downstream_of_allow_set;
1756   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1757     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1758     if (!ShouldProcess(*root.node) || !f16_allowlist_.count(root.node->op())) {
1759       continue;
1760     }
1761     DfsTypeTraversal(
1762         graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
1763         DfsTypePredicates::Enter([&](int idx) -> bool {
1764           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1765           return idx == root_idx ||
1766                  (!downstream_of_allow_set.count(idx) &&
1767                   !f16_allowlist_.count(item.node->op()) &&
1768                   !deny_set.count(idx) && ShouldProcess(*item.node) &&
1769                   // TODO(benbarsdell): Consider allowing propagation through
1770                   // ops that are already float16 in order to reduce the number
1771                   // of casts.
1772                   IsFloat32(item) && SupportsF16(item) &&
1773                   (f16_clearlist_.count(item.node->op()) ||
1774                    f16_inferlist_.count(item.node->op())));
1775         }),
1776         DfsTypeCallbacks::PreOrder(
1777             [&](int idx) { downstream_of_allow_set.insert(idx); }));
1778   }
1779 
1780   // Set nodes that are both downstream and upstream of allow ops to allow.
1781   absl::flat_hash_set<int> upstream_of_allow_set;
1782   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1783     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1784     if (!ShouldProcess(*root.node) || upstream_of_allow_set.count(root_idx) ||
1785         !f16_allowlist_.count(root.node->op())) {
1786       continue;
1787     }
1788     DfsTypeTraversal(
1789         graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
1790         DfsTypePredicates::Enter([&](int idx) -> bool {
1791           return idx == root_idx || (!upstream_of_allow_set.count(idx) &&
1792                                      downstream_of_allow_set.count(idx));
1793         }),
1794         DfsTypeCallbacks::PreOrder([&](int idx) {
1795           upstream_of_allow_set.insert(idx);
1796           bool inserted = allow_set->insert(idx).second;
1797           if (VLOG_IS_ON(2) && inserted) {
1798             const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1799             VLOG(2) << "Painting type " << item.type_attr.DebugString()
1800                     << " of " << item.node->op() << " node "
1801                     << item.node->name() << " ALLOW";
1802           }
1803         }));
1804   }
1805 }
1806 
1807 void AutoMixedPrecisionImpl::PropagateAllowThroughClear(
1808     const absl::flat_hash_set<int>& deny_set,
1809     absl::flat_hash_set<int>* allow_set) const {
1810   // Propagate allow from allow nodes through clearlist ops.
1811   absl::flat_hash_set<int> clear_prop_set;
1812   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1813     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1814     if (!ShouldProcess(*root.node) || clear_prop_set.count(root_idx) ||
1815         !allow_set->count(root_idx)) {
1816       continue;
1817     }
1818     DfsTypeTraversal(
1819         graph_type_view_, {&root},
1820         TypeTraversalDirection::kFollowInputsAndOutputs,
1821         DfsTypePredicates::Enter([&](int idx) -> bool {
1822           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1823           return idx == root_idx ||
1824                  (!allow_set->count(idx) && !deny_set.count(idx) &&
1825                   ShouldProcess(*item.node) && IsFloat32(item) &&
1826                   SupportsF16(item) &&
1827                   (f16_clearlist_.count(item.node->op())) &&
1828                   // We don't propagate (backwards) through nodes that read
1829                   // Variables because it can break the behavior of TensorBoard
1830                   // visualization and/or (in the case of Enter nodes) the model
1831                   // itself. This is only a problem for non-resource variables.
1832                   !NodeImplicitlyReadsNonResourceVariable(*item.node));
1833         }),
1834         DfsTypeCallbacks::PreOrder([&](int idx) {
1835           clear_prop_set.insert(idx);
1836           bool inserted = allow_set->insert(idx).second;
1837           if (VLOG_IS_ON(2) && inserted) {
1838             const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1839             VLOG(2) << "Painting type " << item.type_attr.DebugString()
1840                     << " of " << item.node->op() << " node "
1841                     << item.node->name() << " ALLOW";
1842           }
1843         }));
1844   }
1845 }
1846 
1847 // Sets an infer node to allow if it has an allow fanin and no deny fanin.
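// For example (assuming MatMul is on the allowlist and BiasAdd on the
// inferlist; illustration only): in BF16 mode a MatMul (allow) -> BiasAdd
// (infer) edge gets the BiasAdd painted ALLOW, while a BiasAdd with any
// deny-painted fanin is skipped.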
1848 void AutoMixedPrecisionImpl::AddInferToAllowIfFollowAllow(
1849     const absl::flat_hash_set<int>& deny_set,
1850     absl::flat_hash_set<int>* allow_set) const {
1851   // Currently this only targets the oneDNN (BF16) mode.
1852   if (mode_ != AutoMixedPrecisionMode::BF16) {
1853     return;
1854   }
1855   for (int item_idx = 0; item_idx < graph_type_view_.num_nodes(); ++item_idx) {
1856     const NodeTypeId& item = *graph_type_view_.GetNode(item_idx);
1857     if (!ShouldProcess(*item.node) || deny_set.count(item_idx) ||
1858         allow_set->count(item_idx) || !f16_inferlist_.count(item.node->op()) ||
1859         !IsFloat32(item) || !SupportsF16DataType(item)) {
1860       continue;
1861     }
1862 
1863     bool has_allow_fanin = false;
1864     for (const int fanin : graph_type_view_.GetFanin(item_idx)) {
1865       if (deny_set.count(fanin)) {
1866         has_allow_fanin = false;
1867         break;
1868       }
1869       if (allow_set->count(fanin)) {
1870         has_allow_fanin = true;
1871       }
1872     }
1873     if (has_allow_fanin) {
1874       bool inserted = allow_set->insert(item_idx).second;
1875       if (VLOG_IS_ON(2) && inserted) {
1876         VLOG(2) << "Painting type " << item.type_attr.DebugString() << " of "
1877                 << item.node->op() << " node " << item.node->name() << " ALLOW";
1878       }
1879     }
1880   }
1881 }
1882 
1883 // Some ops have a type_attr that cannot be converted to F16. For example,
1884 // FusedBatchNormV2/FusedBatchNormV3 have a type_attr 'U' that only supports
1885 // float, so such nodes are removed from allow_set.
1886 // Also don't convert quantized ops to FP16.
1887 void AutoMixedPrecisionImpl::RemoveAllowsetWithFp32(
1888     absl::flat_hash_set<int>* allow_set) const {
1889   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1890     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1891     if (f16_allowlist_.count(root.node->op()) && allow_set->count(root_idx) &&
1892         (!SupportsF16DataType(root) || IsQuantized(root))) {
1893       auto erased = allow_set->erase(root_idx);
1894       if (VLOG_IS_ON(2) && erased) {
1895         VLOG(2) << "UnPainting type " << root.type_attr.DebugString()
1896                 << " of node " << root.node->name() << " ALLOW because its op "
1897                 << root.node->op() << " does not support the F16 data type";
1898       }
1899     }
1900   }
1901 }
1902 
1903 // Forces NextIteration nodes and their output Merge node(s) to have the same
1904 // color. Specifically, it removes them all from allow_set if any of the Merge
1905 // nodes is not in allow_set, otherwise it adds the NextIteration node to
1906 // allow_set.
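// This keeps the dtype of the loop-carried value consistent: the NextIteration
// output and its Merge node(s) share the recurrent edge, so painting them
// differently would change the tensor's type between loop iterations.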
1907 Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
1908     absl::flat_hash_set<int>* allow_set) const {
1909   for (const NodeDef& node : graph_->node()) {
1910     if (node.op() == "NextIteration") {
1911       GraphView::OutputPort output_port(&node, 0);
1912       const auto& fanout = graph_view_.GetFanout(output_port);
1913       std::vector<int> merge_idxs;
1914       merge_idxs.reserve(fanout.size());
1915       bool any_merge_is_not_allow = false;
1916       for (const auto& output : fanout) {
1917         const NodeDef& merge_node = *output.node;
1918         if (merge_node.op() != "Merge") {
1919           return errors::FailedPrecondition(
1920               "Expected Merge node after NextIteration, got ", merge_node.op());
1921         }
1922         const absl::optional<int> maybe_merge_idx =
1923             graph_type_view_.GetNodeIndex(merge_node.name(), TypeAttrId("T"));
1924         if (!maybe_merge_idx.has_value()) {
1925           return errors::Internal("Type attribute T of Merge node ",
1926                                   merge_node.name(),
1927                                   " not found in graph view");
1928         }
1929         int merge_idx = maybe_merge_idx.value();
1930         merge_idxs.push_back(merge_idx);
1931         any_merge_is_not_allow =
1932             any_merge_is_not_allow || !allow_set->count(merge_idx);
1933       }
1934       const absl::optional<int> maybe_nextiter_idx =
1935           graph_type_view_.GetNodeIndex(node.name(), TypeAttrId("T"));
1936       if (!maybe_nextiter_idx.has_value()) {
1937         return errors::Internal("Type attribute T of NextIteration node ",
1938                                 node.name(), " not found in graph view");
1939       }
1940       int nextiter_idx = maybe_nextiter_idx.value();
1941       if (any_merge_is_not_allow) {
1942         for (int merge_idx : merge_idxs) {
1943           if (allow_set->erase(merge_idx)) {
1944             VLOG(2) << "Painting type T of Merge node "
1945                     << graph_type_view_.GetNode(merge_idx)->node->name()
1946                     << " DENY to match the color of its sibling Merge nodes "
1947                        "with common NextIteration node "
1948                     << node.name();
1949           }
1950         }
1951         if (allow_set->erase(nextiter_idx)) {
1952           VLOG(2) << "Painting type T of NextIteration node " << node.name()
1953                   << " DENY to match the color of its output Merge node(s)";
1954         }
1955       } else {
1956         if (allow_set->insert(nextiter_idx).second) {
1957           VLOG(2) << "Painting type T of NextIteration node " << node.name()
1958                   << " ALLOW to match the color of its output Merge node(s)";
1959         }
1960       }
1961     }
1962   }
1963   return OkStatus();
1964 }
1965 
1966 // Forces all of the given Tensor List nodes into the same color set.
1967 void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
1968     const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1969     absl::flat_hash_set<int>* allow_set,
1970     absl::flat_hash_set<int>* deny_set) const {
1971   bool any_deny = false;
1972   bool any_allow = false;
1973   std::vector<int> node_type_idxs;
1974   node_type_idxs.reserve(tensor_list_nodes.size());
1975   for (const NodeDef* node : tensor_list_nodes) {
1976     const NodeTypeId& node_type = *GetTensorListFloat32NodeTypeId(*node);
1977     const absl::optional<int> maybe_node_type_idx =
1978         graph_type_view_.GetNodeIndex(node_type);
1979     DCHECK(maybe_node_type_idx.has_value())
1980         << "Type attribute " << node_type.type_attr.DebugString() << " of node "
1981         << node->name() << " not found in graph view";
1982     node_type_idxs.push_back(maybe_node_type_idx.value());
1983   }
1984   for (int node_type_idx : node_type_idxs) {
1985     if (deny_set->count(node_type_idx)) {
1986       any_deny = true;
1987       break;
1988     } else if (allow_set->count(node_type_idx)) {
1989       any_allow = true;
1990     }
1991   }
1992   if (!any_deny && !any_allow) return;
1993   for (int node_type_idx : node_type_idxs) {
1994     const NodeTypeId& node_type = *graph_type_view_.GetNode(node_type_idx);
1995     VLOG(2) << "Painting type " << node_type.type_attr.DebugString() << " of "
1996             << node_type.node->op() << " node " << node_type.node->name() << " "
1997             << (any_deny ? "DENY" : "ALLOW")
1998             << " because at least one of its siblings is "
1999             << (any_deny ? "DENY" : "ALLOW");
2000     if (any_deny) {
2001       allow_set->erase(node_type_idx);
2002       deny_set->insert(node_type_idx);
2003     } else {
2004       allow_set->insert(node_type_idx);
2005     }
2006   }
2007 }
2008 
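// Returns true for an Identity node that reads a (non-resource) Variable or
// VariableV2, and for an Enter node whose input is such an Identity (possibly
// through a chain of Enter nodes).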
2009 bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable(
2010     const NodeDef& node) const {
2011   if (node.op() == "Identity" || node.op() == "Enter") {
2012     GraphView::InputPort node_input(&node, 0);
2013     MutableGraphView::OutputPort prev_output =
2014         graph_view_.GetRegularFanin(node_input);
2015     const NodeDef* input = prev_output.node;
2016     if (input && ((node.op() == "Identity" && (input->op() == "Variable" ||
2017                                                input->op() == "VariableV2")) ||
2018                   (node.op() == "Enter" &&
2019                    NodeImplicitlyReadsNonResourceVariable(*input)))) {
2020       return true;
2021     }
2022   }
2023   return false;
2024 }
2025 
2026 // This adds existing Cast nodes to allow_set if all of their outputs are allow,
2027 // avoiding the need to add a new Cast node after an existing Cast.
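// For example, an existing Cast with DstT = DT_FLOAT whose consumers are all
// allow-painted is itself painted allow, so the final pass can simply change
// its DstT to f16 (in the non-emulated path) instead of appending a second
// Cast behind it.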
2028 void AutoMixedPrecisionImpl::MakeCastsAllowIfAllOutputsAllow(
2029     absl::flat_hash_set<int>* allow_set) const {
2030   int num_nodes_preop = graph_->node_size();
2031   for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
2032     NodeDef* node = graph_->mutable_node(node_idx);
2033     NodeTypeId node_type(node, TypeAttrId("DstT"));
2034     if (node->op() != "Cast" || !IsFloat32(node_type)) {
2035       continue;
2036     }
2037     bool all_fanouts_allow = true;
2038     MutableGraphView::OutputPort src(node, 0);
2039     const auto& fanout = graph_view_.GetFanout(src);
2040     for (const MutableGraphView::InputPort& dst : fanout) {
2041       TypeAttrId dst_type_attr =
2042           node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
2043       const absl::optional<int> maybe_dst_type_idx =
2044           graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
2045       DCHECK(maybe_dst_type_idx.has_value())
2046           << "Type attribute " << dst_type_attr.DebugString() << " of node "
2047           << dst.node->name() << " not found in graph view";
2048       int dst_type_idx = maybe_dst_type_idx.value();
2049       bool dst_is_allow = allow_set->count(dst_type_idx);
2050       if (!dst_is_allow) {
2051         all_fanouts_allow = false;
2052         break;
2053       }
2054     }
2055     if (!fanout.empty() && all_fanouts_allow) {
2056       const absl::optional<int> maybe_node_type_idx =
2057           graph_type_view_.GetNodeIndex(node_type);
2058       DCHECK(maybe_node_type_idx.has_value())
2059           << "Type attribute " << node_type.type_attr.DebugString()
2060           << " of node " << node_type.node->name()
2061           << " not found in graph view";
2062       int node_type_idx = maybe_node_type_idx.value();
2063       allow_set->insert(node_type_idx);
2064     }
2065   }
2066 }
2067 
2068 // Insert a Cast op at the output of a node.
2069 // CastType indicates the type of the inserted Cast op:
2070 //   FP16: cast to float16
2071 //   FP32: cast to float32
2072 //   AUTO: cast to a data type that matches the fanout data type
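// For example (hypothetical consumers), with CastType::AUTO and an
// allow-painted source feeding one allow-painted and one non-allow consumer,
// only the edge to the non-allow consumer is rewired through a single newly
// created fp32 Cast node, and that Cast node is reused for any further fanouts
// that need the same conversion.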
2073 StatusOr<NodeDef*> AutoMixedPrecisionImpl::InsertCastNodeAtFanout(
2074     const absl::flat_hash_set<int>& allow_set, const bool src_is_allow,
2075     const CastType& cast_type, MutableGraphView::OutputPort& src) {
2076   NodeDef* added_cast_node = nullptr;
2077   // Note: This is copied so that edges can be modified inside the loop.
2078   auto fanout = graph_view_.GetFanout(src);
2079   for (const MutableGraphView::InputPort& dst : fanout) {
2080     TypeAttrId dst_type_attr =
2081         node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
2082     const absl::optional<int> maybe_dst_type_idx =
2083         graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
2084     if (!maybe_dst_type_idx.has_value()) {
2085       return errors::Internal("Type attribute ", dst_type_attr.DebugString(),
2086                               " of ", dst.node->op(), " node ",
2087                               dst.node->name(), " not found in graph view");
2088     }
2089     int dst_type_idx = maybe_dst_type_idx.value();
2090     bool dst_is_allow = allow_set.count(dst_type_idx);
2091     bool to_f16 = false;
2092     bool should_cast = false;
2093     switch (cast_type) {
2094       case CastType::AUTO:
2095         if (src_is_allow != dst_is_allow) {
2096           to_f16 = dst_is_allow;
2097           should_cast = true;
2098         }
2099         break;
2100       case CastType::FP16:
2101         to_f16 = true;
2102         should_cast = true;
2103         break;
2104       case CastType::FP32:
2105         to_f16 = false;
2106         should_cast = true;
2107         break;
2108       default:
2109         return errors::Internal("Invalid Cast Type: ",
2110                                 static_cast<int>(cast_type));
2111     }
2112 
2113     if (!should_cast) continue;
2114     if (added_cast_node == nullptr) {
2115       VLOG(1) << "Inserting cast to "
2116               << (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT") << " at "
2117               << src.node->op() << " " << src.node->name() << ":"
2118               << src.port_id;
2119       added_cast_node = graph_view_.AddNode(
2120           BuildCastNode(src, dst, to_f16, src.node->device()));
2121       if (to_f16 && !IsConstant(*src.node) && !IsVariable(*src.node) &&
2122           !NodeImplicitlyReadsNonResourceVariable(*src.node)) {
2123         ++num_nonvar_casts_to_f16_;
2124       }
2125     }
2126     TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
2127         dst.node->name(), dst.port_id, {added_cast_node->name(), 0}));
2128   }
2129   return added_cast_node;
2130 }
2131 
2132 // Returns the destination data type of a Cast op. CHECK-fails (rather than
2133 // returning an error) if the node is not a Cast op.
2134 StatusOr<DataType> AutoMixedPrecisionImpl::GetCastToType(
2135     const NodeDef* node) const {
2136   CHECK_EQ(node->op(), "Cast")  // Crash OK
2137       << "Node " << node->name() << " is not a Cast op";
2138   return node->attr().at("DstT").type();
2139 }
2140 
2141 // Collect the output ports of a node based on a type attribute and append them
2142 // to a vector.
2143 //   Input: type_attr
2144 //   Input: node
2145 //   Output: output_ports
2146 void AutoMixedPrecisionImpl::CollectOutputPorts(
2147     const TypeAttrId& type_attr, NodeDef* node,
2148     std::vector<MutableGraphView::OutputPort>& output_ports) const {
2149   for (int port_id : node_type_map_.GetOutputPorts(*node, type_attr)) {
2150     output_ports.emplace_back(node, port_id);
2151   }
2152 }
2153 
2154 // Changes all allow-painted type attributes to DT_HALF or DT_BFLOAT16, and
2155 // inserts Cast nodes at node outputs for all edges that connect
2156 // allow-painted <-> non-allow-painted type attributes.
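// In CPU mode the env var TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_EMULATE_FP16
// (default: true) selects the emulation path below, where an allow-painted op
// keeps its fp32 type attribute and is instead bracketed by fp32 Casts at its
// fanins and fp16 Casts at its fanouts.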
2157 Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
2158     const absl::flat_hash_set<int>& allow_set) {
2159   int num_nodes_changed = 0;
2160   const int num_nodes_preop = graph_->node_size();
2161 
2162   bool emulate_f16 = false;
2163   if (mode_ == AutoMixedPrecisionMode::CPU) {
2164     TF_CHECK_OK(
2165         ReadBoolFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_EMULATE_FP16",
2166                            /*default_val=*/true, &emulate_f16));
2167   }
2168 
2169   VLOG(1) << "Setting emulate_f16 = " << emulate_f16;
2170 
2171   for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
2172     NodeDef* node = graph_->mutable_node(node_idx);
2173     for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(*node)) {
2174       const absl::optional<int> maybe_node_type_idx =
2175           graph_type_view_.GetNodeIndex(node->name(), type_attr);
2176       if (!maybe_node_type_idx.has_value()) {
2177         return errors::Internal("Type attribute ", type_attr.DebugString(),
2178                                 " of ", node->op(), " node ", node->name(),
2179                                 " not found in graph view");
2180       }
2181       int node_type_idx = maybe_node_type_idx.value();
2182       if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue;
2183       bool src_is_allow = allow_set.count(node_type_idx);
2184 
2185       // Include output ports of fp32 nodes, real fp16 nodes,
2186       // and the fp16 Cast nodes at the fanout of emulated fp16 ops.
2187       std::vector<MutableGraphView::OutputPort> output_ports;
2188 
2189       if (src_is_allow) {
2190         if (emulate_f16) {
2191           // For emulated fp16 op, we do not change the op type but instead
2192           // insert fp32 Cast at the fanin and fp16 Cast at the fanout
2193           for (int port_id : node_type_map_.GetInputPorts(*node, type_attr)) {
2194             VLOG(2) << "Cast to F32 at fanin of node " << node->name() << ":"
2195                     << port_id;
2196             MutableGraphView::InputPort dst(node, port_id);
2197             MutableGraphView::OutputPort src = graph_view_.GetRegularFanin(dst);
2198             NodeDef* added_cast_node = graph_view_.AddNode(
2199                 BuildCastNode(src, dst, /*to_f16=*/false, src.node->device()));
2200             VLOG(1) << "Inserting cast to DT_FLOAT at " << src.node->op() << " "
2201                     << src.node->name() << ":" << src.port_id;
2202             TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
2203                 dst.node->name(), dst.port_id, {added_cast_node->name(), 0}));
2204           }
2205           // Cast to fp16 at outputs
2206           for (int port_id : node_type_map_.GetOutputPorts(*node, type_attr)) {
2207             MutableGraphView::OutputPort src(node, port_id);
2208             VLOG(2) << "Cast to F16 at fanout of node " << node->name() << ":"
2209                     << port_id;
2210             TF_ASSIGN_OR_RETURN(NodeDef * added_cast_node,
2211                                 InsertCastNodeAtFanout(allow_set, src_is_allow,
2212                                                        CastType::FP16, src));
2213             if (added_cast_node != nullptr) {
2214               output_ports.emplace_back(added_cast_node, /*port_id=*/0);
2215             }
2216           }
2217         } else {
2218           VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
2219                   << node->op() << " node " << node->name() << " to "
2220                   << DataTypeString(target_dtype_);
2221           if (!SetDataType(node, type_attr, target_dtype_)) {
2222             return errors::Internal("Failed to set type attribute");
2223           }
2224           ++num_nodes_changed;
2225           CollectOutputPorts(type_attr, node, output_ports);
2226         }
2227       } else {
2228         CollectOutputPorts(type_attr, node, output_ports);
2229       }
2230 
2231       // If the fanouts require a different data type from the output of the
2232       // current node, insert a Cast op.
2233       for (auto output_port : output_ports) {
2234         VLOG(2) << "Cast to required data type at fanout of node "
2235                 << output_port.node->name() << ":" << output_port.port_id;
2236         TF_RETURN_IF_ERROR(InsertCastNodeAtFanout(allow_set, src_is_allow,
2237                                                   CastType::AUTO, output_port)
2238                                .status());
2239       }
2240     }
2241   }
2242 
2243   // Use Python type names (e.g. float16) instead of C++ type names (e.g. half)
2244   // since many Python users will see this message.
2245   const char* type_str = target_dtype_ == DT_HALF ? "float16" : "bfloat16";
2246   LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop
2247             << " nodes to " << type_str << " precision using "
2248             << num_nonvar_casts_to_f16_ << " cast(s) to " << type_str
2249             << " (excluding Const and Variable casts)";
2250   return OkStatus();
2251 }
2252 
2253 int GetNumGPUs(const Cluster& cluster) {
2254   if (ShouldSimulateGpu()) {
2255     return 1;
2256   }
2257   auto devices = cluster.GetDevices();
2258   int num_gpus = 0;
2259   for (const auto& device : devices) {
2260     const DeviceProperties& device_properties = device.second;
2261     if (device_properties.type() == "GPU" &&
2262         (ShouldIgnorePerformance() || HasFastFP16Support(device_properties))) {
2263       num_gpus++;
2264     }
2265   }
2266   return num_gpus;
2267 }
2268 
2269 }  // end namespace
2270 
2271 Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
2272                                     GraphDef* output) {
2273   if (cluster == nullptr) {
2274     return errors::InvalidArgument("cluster == nullptr");
2275   }
2276 
2277 #if !defined(INTEL_MKL)
2278   if (mode_ == AutoMixedPrecisionMode::BF16) {
2279     return errors::Unimplemented(
2280         "The auto_mixed_precision_onednn_bfloat16 optimizer cannot be used "
2281         "since this build of TensorFlow is not compiled with oneDNN support "
2282         "for bfloat16. "
2283         "For information on oneDNN builds, see: "
2284         "https://software.intel.com/en-us/articles/intel-optimization-for-"
2285         "tensorflow-installation-guide");
2286   }
2287 #endif  // INTEL_MKL
2288 
2289   // Start by copying input graph to output.
2290   *output = item.graph;
2291 
2292   int num_gpus = GetNumGPUs(*cluster);
2293   if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) {
2294     // AutoMixedPrecision is currently only tuned for GPU.
2295     LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name()
2296                  << " graph optimizer";
2297     return OkStatus();
2298   }
2299 
2300   if (num_gpus >= 1 && mode_ == AutoMixedPrecisionMode::BF16) {
2301     LOG(WARNING) << "Note: GPUs detected. Using " << name()
2302                  << " graph optimizer configured for BFloat16 on CPUs";
2303   }
2304 
2305   // Optimize the output graph in-place.
2306   AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output,
2307                                    item.id, mode_);
2308   if (item.id == "tf_graph") {
2309     LOG(INFO) << "Running " << name() << " graph optimizer";
2310   } else {
2311     VLOG(1) << "Running " << name() << " graph optimizer on " << item.id;
2312   }
2313   Status status = optimizer.Optimize();
2314   if (!status.ok()) {
2315     // Restore the original graph.
2316     *output = item.graph;
2317     LOG(WARNING) << name() << " graph optimizer FAILED: " << status.ToString();
2318   }
2319   return status;
2320 }
2321 
2322 }  // end namespace grappler
2323 }  // end namespace tensorflow
2324