1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
17
18 #include <fstream>
19 #include <memory>
20 #include <unordered_map>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/op.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/grappler/clusters/cluster.h"
32 #include "tensorflow/core/grappler/costs/virtual_placer.h"
33 #include "tensorflow/core/grappler/devices.h"
34 #include "tensorflow/core/grappler/grappler_item.h"
35 #include "tensorflow/core/grappler/mutable_graph_view.h"
36 #include "tensorflow/core/grappler/op_types.h"
37 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h"
38 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
39 #include "tensorflow/core/grappler/utils.h"
40 #include "tensorflow/core/lib/io/path.h"
41 #include "tensorflow/core/lib/strings/numbers.h"
42 #include "tensorflow/core/lib/strings/str_util.h"
43 #include "tensorflow/core/lib/strings/strcat.h"
44 #include "tensorflow/core/platform/logging.h"
45 #include "tensorflow/core/platform/status.h"
46 #include "tensorflow/core/util/env_var.h"
47
48 namespace tensorflow {
49 namespace grappler {
50 namespace {
51
52 bool ShouldSimulateGpu() {
53 bool is_enabled = [] {
54 bool ret = false;
55 string var;
56 TF_CHECK_OK(ReadStringFromEnvVar(
57 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU", "", &var));
58 TF_CHECK_OK(
59 ReadBoolFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU",
60 /*default_val=*/false, &ret));
61 return ret;
62 }();
63 return is_enabled;
64 }
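// Illustrative usage note (not part of the original source): GPU simulation is
// controlled purely through the environment, e.g. in a POSIX shell:
//   export TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU=1
// ReadBoolFromEnvVar also accepts "true"/"false"/"0" (case-insensitive).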
65
66 #if GOOGLE_CUDA
67 const std::pair<int, int> kMinGPUArch = {7, 0};
68 #else
69 const std::pair<int, int> kMinGPUArch = {0, 0};
70 #endif
71
72 const char kSuffix[] = "AutoMixedPrecision";
73 const char kCastToFp16[] = "CastToFp16";
74 const char kCastToBf16[] = "CastToBf16";
75 const char kCastToFp32[] = "CastToFp32";
76
77 #if GOOGLE_CUDA
78 // Returns the GPU architecture (compute capability) as a (major, minor) pair.
79 std::pair<int, int> GetDeviceGPUArch(
80 const DeviceProperties& device_properties) {
81 if (device_properties.type() != "GPU") return {0, 0};
82 string arch_str = device_properties.environment().at("architecture");
83 std::vector<string> split_arch_str = str_util::Split(arch_str, '.');
84 if (split_arch_str.empty()) {
85 return {0, 0};
86 }
87
88 int major, minor;
89 if (!strings::safe_strto32(split_arch_str[0], &major)) {
90 return {0, 0};
91 }
92
93 if (split_arch_str.size() > 1) {
94 if (strings::safe_strto32(split_arch_str[1], &minor)) {
95 return {major, minor};
96 } else {
97 return {0, 0};
98 }
99 } else {
100 return {major, 0};
101 }
102 }
103 #endif
104
105 // Returns true if the device has fast FP16 support.
106 // For CUDA, returns true if the GPU architecture (compute capability) is at
107 // least kMinGPUArch. For ROCm, returns true if the detected AMD GPU's gfx
108 // architecture string is in the set of FP16-capable devices. Returns false otherwise.
109 bool HasFastFP16Support(const DeviceProperties& props) {
110 #if GOOGLE_CUDA
111 return GetDeviceGPUArch(props) >= kMinGPUArch;
112 #elif TENSORFLOW_USE_ROCM
113 absl::flat_hash_set<std::string> FP16SupportedDevices = {{"gfx906"},
114 {"gfx908"}};
115 std::string gcnArchName = props.environment().at("architecture");
116 std::vector<std::string> gpu_arch = absl::StrSplit(gcnArchName, ":");
117 return !gpu_arch.empty() && FP16SupportedDevices.contains(gpu_arch[0]);
118 #endif
119 return ShouldSimulateGpu();
120 }
121
122 // Instances of this class represent unique type attribute identifiers within a
123 // node. It handles regular type attributes, list type attributes (where
124 // type_index is set to the index in the type list), and fixed types.
125 struct TypeAttrId {
126 static constexpr int kSingleType = -1;
127
128 explicit TypeAttrId(const string& _attr_name, int _type_index = kSingleType)
129 : attr_name(_attr_name),
130 type_index(_type_index),
131 fixed_type(DT_INVALID) {}
132
133 explicit TypeAttrId(DataType _fixed_type)
134 : attr_name(), type_index(kSingleType), fixed_type(_fixed_type) {}
135
136 bool operator==(const TypeAttrId& other) const {
137 return attr_name == other.attr_name && type_index == other.type_index &&
138 fixed_type == other.fixed_type;
139 }
140
141 bool operator<(const TypeAttrId& other) const {
142 return std::make_tuple(attr_name, type_index, fixed_type) <
143 std::make_tuple(other.attr_name, other.type_index, other.fixed_type);
144 }
145
146 template <typename H>
147 friend H AbslHashValue(H h, const TypeAttrId& ta) {
148 return H::combine(std::move(h), ta.attr_name, ta.type_index, ta.fixed_type);
149 }
150
151 string DebugString() const {
152 if (!attr_name.empty()) {
153 if (type_index == kSingleType) {
154 return attr_name;
155 } else {
156 return strings::StrCat(attr_name, "[", type_index, "]");
157 }
158 } else {
159 return tensorflow::DataTypeString(fixed_type);
160 }
161 }
162
163 string attr_name;
164 // If attr_name is a list(type), this is the index into the list. Otherwise
165 // this is kSingleType.
166 int type_index;
167 DataType fixed_type;
168 };
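// Illustrative examples (not part of the original source): for a Cast node,
// the "SrcT" and "DstT" attributes are identified by TypeAttrId("SrcT") and
// TypeAttrId("DstT"), and DebugString() prints the attribute name. For a
// list(type) attribute such as IdentityN's "T", each list element gets its own
// id, e.g. TypeAttrId("T", 1), printed as "T[1]". A port with no type
// attribute (a fixed type) is represented as e.g. TypeAttrId(DT_INT32),
// printed as "int32".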
169
170 // Returns the data type of the given type attribute, or DT_INVALID if the type
171 // attribute is invalid.
172 DataType GetDataType(const NodeDef& node, const TypeAttrId& type_attr) {
173 if (type_attr.attr_name.empty()) {
174 return type_attr.fixed_type;
175 }
176 if (!node.attr().count(type_attr.attr_name)) {
177 return DT_INVALID;
178 }
179 const AttrValue& attr_value = node.attr().at(type_attr.attr_name);
180 if (type_attr.type_index == TypeAttrId::kSingleType) {
181 return attr_value.type();
182 } else {
183 if (type_attr.type_index < 0 ||
184 type_attr.type_index >= attr_value.list().type_size()) {
185 return DT_INVALID;
186 }
187 return attr_value.list().type(type_attr.type_index);
188 }
189 }
190
191 // Sets the data type of the given type attribute. Returns false if the type
192 // attribute is invalid, otherwise true.
193 bool SetDataType(NodeDef* node, const TypeAttrId& type_attr, DataType type) {
194 if (type_attr.attr_name.empty() || !node->attr().count(type_attr.attr_name)) {
195 return false;
196 }
197 AttrValue& attr_value = node->mutable_attr()->at(type_attr.attr_name);
198 if (type_attr.type_index == TypeAttrId::kSingleType) {
199 attr_value.set_type(type);
200 } else {
201 if (type_attr.type_index < 0 ||
202 type_attr.type_index >= attr_value.list().type_size()) {
203 return false;
204 }
205 attr_value.mutable_list()->set_type(type_attr.type_index, type);
206 }
207 return true;
208 }
209
210 std::vector<std::pair<int, int>> ArgDefIndexes(const NodeDef& node, int arg_idx,
211 const OpDef::ArgDef& arg_def) {
212 std::vector<std::pair<int, int>> argdef_inds;
213 if (!arg_def.type_list_attr().empty()) {
214 int num_types = node.attr().at(arg_def.type_list_attr()).list().type_size();
215 for (int type_idx = 0; type_idx < num_types; ++type_idx) {
216 argdef_inds.push_back({arg_idx, type_idx});
217 }
218 } else {
219 int num_repeat = 1;
220 if (node.attr().count(arg_def.number_attr())) {
221 num_repeat = node.attr().at(arg_def.number_attr()).i();
222 }
223 argdef_inds.insert(argdef_inds.end(), num_repeat, {arg_idx, -1});
224 }
225 return argdef_inds;
226 }
227
228 // Returns a pair (arg_index, type_index) for each input to the node, where
229 // arg_index is the index of the input_arg in op_def and type_index is the index
230 // of the type in type_list_attr (only defined for list arguments).
231 std::vector<std::pair<int, int>> InputPortArgDefIndexes(const NodeDef& node,
232 const OpDef& op_def) {
233 std::vector<std::pair<int, int>> argdef_inds;
234 argdef_inds.reserve(op_def.input_arg_size()); // Final size may differ.
235 for (int arg_idx = 0; arg_idx < op_def.input_arg_size(); ++arg_idx) {
236 const OpDef::ArgDef& arg_def = op_def.input_arg(arg_idx);
237 auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
238 argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
239 arg_results.end());
240 }
241 return argdef_inds;
242 }
243
244 // Returns a pair (arg_index, type_index) for each output of the node, where
245 // arg_index is the index of the output_arg in op_def and type_index is the
246 // index of the type in type_list_attr (only defined for list arguments).
247 std::vector<std::pair<int, int>> OutputPortArgDefIndexes(const NodeDef& node,
248 const OpDef& op_def) {
249 std::vector<std::pair<int, int>> argdef_inds;
250 argdef_inds.reserve(op_def.output_arg_size()); // Final size may differ.
251 for (int arg_idx = 0; arg_idx < op_def.output_arg_size(); ++arg_idx) {
252 const OpDef::ArgDef& arg_def = op_def.output_arg(arg_idx);
253 auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
254 argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
255 arg_results.end());
256 }
257 return argdef_inds;
258 }
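// Illustrative example (hypothetical node, not part of the original source):
// for a ConcatV2 node with N = 3 value inputs plus the int32 axis input,
// InputPortArgDefIndexes returns {{0, -1}, {0, -1}, {0, -1}, {1, -1}}: the
// "values" arg (arg_idx 0) is repeated number_attr("N") times and "axis"
// (arg_idx 1) appears once. An arg declared with a type_list_attr instead
// contributes one entry per list element: {arg_idx, 0}, {arg_idx, 1}, ...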
259
260 TypeAttrId GetTypeAttrId(const OpDef::ArgDef& arg_def, int arg_type_index) {
261 if (!arg_def.type_list_attr().empty()) {
262 return TypeAttrId(arg_def.type_list_attr(), arg_type_index);
263 } else if (!arg_def.type_attr().empty()) {
264 return TypeAttrId(arg_def.type_attr());
265 } else {
266 return TypeAttrId(arg_def.type());
267 }
268 }
269
270 std::vector<int> NonControlInputs(const NodeDef& node) {
271 std::vector<int> pos;
272 for (int i = 0; i < node.input_size(); i++) {
273 if (!IsControlInput(node.input(i))) {
274 pos.push_back(i);
275 }
276 }
277 return pos;
278 }
279
280 // A utility class to lookup node type attributes and type attribute <->
281 // input/output port mappings.
282 class NodeTypeAttrMap {
283 public:
284 NodeTypeAttrMap() {}
285
286 explicit NodeTypeAttrMap(const GraphDef& graph) { TF_CHECK_OK(Init(graph)); }
287
288 Status Init(const GraphDef& graph) {
289 if (graph_ != nullptr) {
290 return errors::InvalidArgument("NodeTypeAttrMap is already initialized.");
291 }
292 graph_ = &graph;
293 function_library_.reset(
294 new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
295 for (const NodeDef& node : graph.node()) {
296 TF_RETURN_IF_ERROR(AddNode(node));
297 }
298 return OkStatus();
299 }
300
301 bool is_initialized() const { return graph_ != nullptr; }
302
303 // Returns the set of all type attributes in the given node.
304 absl::flat_hash_set<TypeAttrId> GetTypeAttrs(const NodeDef& node) const {
305 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
306 absl::flat_hash_set<TypeAttrId> type_attrs;
307 const auto iter = type2io_.find(&node);
308 CHECK(iter != type2io_.end()); // Crash Ok
309 for (const auto& key_value : iter->second) {
310 type_attrs.insert(key_value.first);
311 }
312 return type_attrs;
313 }
314
315 const absl::flat_hash_set<int>& GetInputPorts(
316 const NodeDef& node, const TypeAttrId& type_attr) const {
317 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
318 return type2io_.at(&node).at(type_attr).first;
319 }
320
321 const absl::flat_hash_set<int>& GetOutputPorts(
322 const NodeDef& node, const TypeAttrId& type_attr) const {
323 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
324 return type2io_.at(&node).at(type_attr).second;
325 }
326
327 TypeAttrId GetInputTypeAttr(const NodeDef& node, int port) const {
328 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
329 const auto iter = io2type_.find(&node);
330 DCHECK(iter != io2type_.end())
331 << "Node " << node.name() << " doesn't exist in a graph";
332 auto type_vec = io2type_.at(&node).first;
333 CHECK_GE(port, 0); // Crash Ok
334 CHECK_LT(port, type_vec.size()); // Crash Ok
335 return type_vec[port];
336 }
337
338 TypeAttrId GetOutputTypeAttr(const NodeDef& node, int port) const {
339 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
340 auto type_vec = io2type_.at(&node).second;
341 CHECK_GE(port, 0); // Crash Ok
342 CHECK_LT(port, type_vec.size()); // Crash Ok
343 return type_vec[port];
344 }
345
346 private:
347 Status AddNode(const NodeDef& node) {
348 const OpDef* op_def_ptr = nullptr;
349 TF_RETURN_IF_ERROR(function_library_->LookUpOpDef(node.op(), &op_def_ptr));
350 const OpDef& op_def = *op_def_ptr;
351 auto& type2io_entry = type2io_[&node];
352 auto& io2type_entry = io2type_[&node];
353 auto input_arg_inds = InputPortArgDefIndexes(node, op_def);
354 if (NonControlInputs(node).size() != input_arg_inds.size()) {
355 return errors::InvalidArgument(
356 "Expected ", node.op(), " node ", node.name(), " to have ",
357 input_arg_inds.size(), " non-control input(s), but got ",
358 node.input_size());
359 }
360 // Note that the mappings generated here include inputs/outputs with fixed
361 // types. This makes the mappings complete (all inputs and outputs are
362 // included), and allows the graph rewriter to propagate deny paint
363 // from/through ops with fixed types.
364 io2type_entry.first.reserve(input_arg_inds.size());
365 for (int i = 0; i < static_cast<int>(input_arg_inds.size()); ++i) {
366 const auto& arg_inds = input_arg_inds[i];
367 const OpDef::ArgDef& arg_def = op_def.input_arg(arg_inds.first);
368 TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
369 if (!type_attr.attr_name.empty() &&
370 !node.attr().count(type_attr.attr_name)) {
371 return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
372 " is not present in node ", node.name());
373 }
374 type2io_entry[type_attr].first.insert(i);
375 io2type_entry.first.push_back(type_attr);
376 }
377
378 auto output_arg_inds = OutputPortArgDefIndexes(node, op_def);
379 io2type_entry.second.reserve(output_arg_inds.size());
380 for (int i = 0; i < static_cast<int>(output_arg_inds.size()); ++i) {
381 const auto& arg_inds = output_arg_inds[i];
382 const OpDef::ArgDef& arg_def = op_def.output_arg(arg_inds.first);
383 TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
384 if (!type_attr.attr_name.empty() &&
385 !node.attr().count(type_attr.attr_name)) {
386 return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
387 " is not present in node ", node.name());
388 }
389 type2io_entry[type_attr].second.insert(i);
390 io2type_entry.second.push_back(type_attr);
391 }
392
393 // Also ensure that type attributes that aren't associated with any inputs
394 // or outputs (e.g., StackV2's elem_type) are added to the map.
395 for (const auto& attr : node.attr()) {
396 const string& attr_name = attr.first;
397 if (!attr_name.empty() && attr_name[0] == '_') continue;
398 const AttrValue& attr_value = attr.second;
399 const OpDef::AttrDef* attr_def = FindAttr(attr_name, op_def);
400 if (!attr_def) {
401 return errors::InvalidArgument("AttrDef not found for attribute ",
402 attr_name, " of node ", node.name());
403 }
404 if (attr_def->type() == "type") {
405 type2io_entry[TypeAttrId(attr_name)];
406 } else if (attr_def->type() == "list(type)") {
407 for (int i = 0; i < attr_value.list().type_size(); ++i) {
408 type2io_entry[TypeAttrId(attr_name, i)];
409 }
410 }
411 }
412 return OkStatus();
413 }
414
415 // WARN: `graph_` must outlive this object (node pointers must remain valid).
416 const GraphDef* graph_ = nullptr; // do not own
417 std::unique_ptr<FunctionLibraryDefinition> function_library_;
418
419 typedef absl::flat_hash_set<int> IntSet;
420 // Maps a type attr id -> (input port set, output port set)
421 typedef absl::flat_hash_map<TypeAttrId, std::pair<IntSet, IntSet>> Type2IOMap;
422 // Maps a node -> type attr mapping
423 absl::flat_hash_map<const NodeDef*, Type2IOMap> type2io_;
424 // Maps a port -> type attr id
425 typedef std::vector<TypeAttrId> TypeAttrIdVec;
426 // Maps a node -> (input port mapping, output port mapping)
427 absl::flat_hash_map<const NodeDef*, std::pair<TypeAttrIdVec, TypeAttrIdVec>>
428 io2type_;
429 };
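// Usage sketch (illustrative; mirrors how the optimizer below queries the map):
//
//   NodeTypeAttrMap node_type_map;
//   TF_RETURN_IF_ERROR(node_type_map.Init(graph));
//   for (const TypeAttrId& ta : node_type_map.GetTypeAttrs(node)) {
//     const absl::flat_hash_set<int>& in_ports =
//         node_type_map.GetInputPorts(node, ta);
//     const absl::flat_hash_set<int>& out_ports =
//         node_type_map.GetOutputPorts(node, ta);
//     // ... e.g. decide whether the ports governed by `ta` may be cast to f16.
//   }
//
// Note that `graph` and its nodes must outlive the map, since entries are keyed
// by NodeDef pointer.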
430
431 struct NodeTypeId {
432 NodeTypeId(const NodeDef* _node, const TypeAttrId& _type_attr)
433 : node(_node), type_attr(_type_attr) {}
434
435 const NodeDef* node;
436 TypeAttrId type_attr;
437
438 bool operator==(const NodeTypeId& other) const {
439 return node == other.node && type_attr == other.type_attr;
440 }
441
442 template <typename H>
443 friend H AbslHashValue(H h, const NodeTypeId& nt) {
444 return H::combine(std::move(h), nt.node, nt.type_attr);
445 }
446 };
447
448 struct NodeTypeIdEdge {
449 NodeTypeIdEdge(const NodeTypeId& _src, const NodeTypeId& _dst)
450 : src(_src), dst(_dst) {}
451 NodeTypeId src;
452 NodeTypeId dst;
453 };
454
455 // TODO(benbarsdell): Investigate whether the existing GraphTopologyView can be
456 // used instead of this modified version.
457 // This is just like GraphTopologyView but with (NodeDef, TypeAttrId) pairs as
458 // the vertices instead of just NodeDef.
459 // For example, if node A has output A:0 with TypeAttrId 'T', and node B has
460 // input B:0 with TypeAttrId 'U', and input B:0 connects to output A:0, there
461 // will be an edge from (A, T) to (B, U).
462 class GraphTypeTopologyView {
463 public:
464 GraphTypeTopologyView() = default;
465 explicit GraphTypeTopologyView(bool skip_invalid_edges)
466 : skip_invalid_edges_(skip_invalid_edges) {}
467
468 // Initialize graph topology view from the graph. It's possible to pass
469 // additional edges that do not exist in a graph, but must be respected when
470 // computing graph topology. Example: Tensorflow runtime allows concurrent
471 // execution of dequeue/enqueue ops from the same queue resource, but we might
472 // want to enforce ordering between them for the purpose of graph analysis.
473 Status InitializeFromGraph(const GraphDef& graph,
474 const NodeTypeAttrMap& node_type_map);
475
476 Status AddEphemeralEdges(absl::Span<const NodeTypeIdEdge> ephemeral_edges);
477
478 bool is_initialized() const { return graph_ != nullptr; }
479 int num_nodes() const { return num_nodes_; }
480 const GraphDef* graph() const { return graph_; }
481
482 // Returns true iff the node exists in the underlying graph.
483 bool HasNode(absl::string_view node_name, const TypeAttrId& type_attr) const;
484
485 // Finds a node by name or returns `nullptr` if it's not in the graph.
486 const NodeTypeId* GetNode(absl::string_view node_name,
487 const TypeAttrId& type_attr) const;
488 // Returns a node corresponding to the given node index.
489 const NodeTypeId* GetNode(int node_idx) const;
490
491 // Returns a node index for the given node name, if the name exists in the
492 // underlying graph. Otherwise returns empty optional.
493 const absl::optional<int> GetNodeIndex(absl::string_view node_name,
494 const TypeAttrId& type_attr) const;
495 // Returns a node index for the given node, if the node belongs to the
496 // underlying graph. Otherwise returns empty optional.
497 const absl::optional<int> GetNodeIndex(const NodeTypeId& node) const;
498
499 // Returns all the node indexes that are in the direct fanin of the given
500 // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
501 const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const;
502 // Returns all the node indexes that are in the direct fanout of the given
503 // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
504 const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const;
505
506 private:
507 // The key type used to uniquely identify a type attribute on a node.
508 struct NodeTypeKey : public std::pair<absl::string_view, TypeAttrId> {
509 typedef std::pair<absl::string_view, TypeAttrId> Base;
510
511 // Inherit the set of constructors.
512 using Base::pair;
513
514 template <typename H>
515 friend H AbslHashValue(H h, const NodeTypeKey& nt) {
516 return H::combine(std::move(h), nt.first, nt.second);
517 }
518 };
519
520 // If true, all invalid edges and inputs (src, dst, or input node not found in
521 // the graph) are skipped; otherwise initialization fails with an error.
522 bool skip_invalid_edges_ = false;
523
524 // WARN: `graph_` must outlive this object and graph nodes must not be
525 // destructed, because node names captured with absl::string_view.
526 const GraphDef* graph_ = nullptr; // do not own
527 int num_nodes_ = 0;
528 std::vector<NodeTypeId> node_type_attrs_;
529 absl::flat_hash_map<absl::string_view, int> node_name_to_index_;
530 absl::flat_hash_map<NodeTypeKey, int> node_type_name_to_index_;
531
532 std::vector<absl::InlinedVector<int, 4>> fanins_;
533 std::vector<absl::InlinedVector<int, 2>> fanouts_;
534
535 // We need a valid reference to return from GetFanin/GetFanout if the
536 // `node_idx` argument is outside of the [0, num_nodes_) range.
537 absl::InlinedVector<int, 4> empty_fanin_;
538 absl::InlinedVector<int, 2> empty_fanout_;
539 };
540
541 template <typename T>
542 inline void SortAndRemoveDuplicates(T* v) {
543 std::sort(v->begin(), v->end());
544 v->erase(std::unique(v->begin(), v->end()), v->end());
545 }
546
547 Status GraphTypeTopologyView::InitializeFromGraph(
548 const GraphDef& graph, const NodeTypeAttrMap& node_type_map) {
549 if (graph_ != nullptr) {
550 return errors::InvalidArgument(
551 "GraphTypeTopologyView is already initialized.");
552 }
553
554 graph_ = &graph;
555 int num_nodedefs = graph.node_size();
556 node_name_to_index_.rehash(num_nodedefs);
557
558 // Build maps from name to index.
559 node_type_attrs_.reserve(num_nodedefs); // Only approximate.
560 node_type_name_to_index_.rehash(num_nodedefs); // Only approximate.
561 for (int node_idx = 0; node_idx < num_nodedefs; ++node_idx) {
562 const NodeDef& node = graph.node(node_idx);
563 node_name_to_index_.emplace(node.name(), node_idx);
564
565 for (const TypeAttrId& type_attr : node_type_map.GetTypeAttrs(node)) {
566 int node_type_idx = node_type_attrs_.size();
567 node_type_name_to_index_.emplace(NodeTypeKey(node.name(), type_attr),
568 node_type_idx);
569 node_type_attrs_.emplace_back(&node, type_attr);
570 }
571 }
572 num_nodes_ = node_type_attrs_.size();
573 fanins_.resize(num_nodes_);
574 fanouts_.resize(num_nodes_);
575
576 // Add graph edges to the adjacency lists.
577 for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
578 const NodeTypeId& node_type = node_type_attrs_.at(node_type_idx);
579 auto input_ports =
580 node_type_map.GetInputPorts(*node_type.node, node_type.type_attr);
581 fanins_[node_type_idx].reserve(input_ports.size());
582 for (int port : input_ports) {
583 const string& input = node_type.node->input(port);
584 TensorId tensor = ParseTensorName(input);
585 const auto it = node_name_to_index_.find(tensor.node());
586 const bool valid_input = it != node_name_to_index_.end();
587
588 if (!valid_input) {
589 const string error_message = absl::StrCat(
590 "Non-existent input ", input, " in node ", node_type.node->name());
591 if (skip_invalid_edges_) {
592 VLOG(3) << "Skip error: " << error_message;
593 } else {
594 return errors::InvalidArgument(error_message);
595 }
596 }
597
598 if (valid_input) {
599 const int input_idx = it->second;
600 const NodeDef& input_node = graph_->node(input_idx);
601 TypeAttrId input_type_attr =
602 node_type_map.GetOutputTypeAttr(input_node, tensor.index());
603 const auto it2 = node_type_name_to_index_.find(
604 NodeTypeKey(input_node.name(), input_type_attr));
605 if (it2 == node_type_name_to_index_.end()) {
606 if (!skip_invalid_edges_) {
607 return errors::InvalidArgument("Did not find type attr ",
608 input_type_attr.DebugString(),
609 " in node ", input_node.name());
610 }
611 continue;
612 }
613 int input_node_type_idx = it2->second;
614 fanins_[node_type_idx].push_back(input_node_type_idx);
615 fanouts_[input_node_type_idx].push_back(node_type_idx);
616 }
617 }
618
619 // Dedup the input list while it's still hot in cache.
620 SortAndRemoveDuplicates(&fanins_[node_type_idx]);
621 }
622
623 // Dedup outputs for all the graph nodes.
624 for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
625 SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
626 }
627
628 return OkStatus();
629 }
630
631 Status GraphTypeTopologyView::AddEphemeralEdges(
632 absl::Span<const NodeTypeIdEdge> ephemeral_edges) {
633 // Add ephemeral edges to the adjacency lists.
634 for (const NodeTypeIdEdge& edge : ephemeral_edges) {
635 const auto src = node_name_to_index_.find(edge.src.node->name());
636 const bool valid_src = src != node_name_to_index_.end();
637
638 if (!valid_src) {
639 const string error_message =
640 absl::StrCat("Non-existent src node: ", edge.src.node->name());
641 if (skip_invalid_edges_) {
642 VLOG(0) << "Skip error: " << error_message;
643 } else {
644 return errors::InvalidArgument(error_message);
645 }
646 }
647
648 const auto dst = node_name_to_index_.find(edge.dst.node->name());
649 const bool valid_dst = dst != node_name_to_index_.end();
650
651 if (!valid_dst) {
652 const string error_message =
653 absl::StrCat("Non-existent dst node: ", edge.dst.node->name());
654 if (skip_invalid_edges_) {
655 VLOG(0) << "Skip error: " << error_message;
656 } else {
657 return errors::InvalidArgument(error_message);
658 }
659 }
660
661 if (valid_dst && valid_src) {
662 // TODO(benbarsdell): Check for failure.
663 int src_node_type_idx = node_type_name_to_index_.at(
664 NodeTypeKey(edge.src.node->name(), edge.src.type_attr));
665 int dst_node_type_idx = node_type_name_to_index_.at(
666 NodeTypeKey(edge.dst.node->name(), edge.dst.type_attr));
667 fanins_[dst_node_type_idx].push_back(src_node_type_idx);
668 fanouts_[src_node_type_idx].push_back(dst_node_type_idx);
669 }
670 }
671
672 // Dedup inputs and outputs for all the graph nodes.
673 for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
674 SortAndRemoveDuplicates(&fanins_[node_type_idx]);
675 SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
676 }
677
678 return OkStatus();
679 }
680
681 bool GraphTypeTopologyView::HasNode(absl::string_view node_name,
682 const TypeAttrId& type_attr) const {
683 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
684 NodeTypeKey key(node_name, type_attr);
685 const auto it = node_type_name_to_index_.find(key);
686 return it != node_type_name_to_index_.end();
687 }
688
689 const NodeTypeId* GraphTypeTopologyView::GetNode(
690 absl::string_view node_name, const TypeAttrId& type_attr) const {
691 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
692 NodeTypeKey key(node_name, type_attr);
693 const auto it = node_type_name_to_index_.find(key);
694 return it == node_type_name_to_index_.end()
695 ? nullptr
696 : &node_type_attrs_.at(it->second);
697 }
698
699 const NodeTypeId* GraphTypeTopologyView::GetNode(int node_idx) const {
700 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
701 DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range";
702 return &node_type_attrs_.at(node_idx);
703 }
704
705 const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
706 absl::string_view node_name, const TypeAttrId& type_attr) const {
707 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
708 NodeTypeKey key(node_name, type_attr);
709 const auto it = node_type_name_to_index_.find(key);
710 DCHECK(it != node_type_name_to_index_.end())
711 << "Node doesn't exist in a graph";
712 return it == node_type_name_to_index_.end() ? absl::nullopt
713 : absl::make_optional(it->second);
714 }
715
716 const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
717 const NodeTypeId& node) const {
718 return GetNodeIndex(node.node->name(), node.type_attr);
719 }
720
721 const absl::InlinedVector<int, 4>& GraphTypeTopologyView::GetFanin(
722 int node_idx) const {
723 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
724 const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
725 DCHECK(is_valid_node_idx) << "node_idx is out of range";
726 return is_valid_node_idx ? fanins_[node_idx] : empty_fanin_;
727 }
728
729 const absl::InlinedVector<int, 2>& GraphTypeTopologyView::GetFanout(
730 int node_idx) const {
731 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
732 const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
733 DCHECK(is_valid_node_idx) << "node_idx is out of range";
734 return is_valid_node_idx ? fanouts_[node_idx] : empty_fanout_;
735 }
736
737 enum class TypeTraversalDirection {
738 kFollowInputs,
739 kFollowOutputs,
740 kFollowInputsAndOutputs,
741 };
742
743 // Encapsulate DFS callbacks that will be called during the graph traversal.
744 //
745 // If non-empty, the `pre_order` and `post_order` functors will be called on
746 // each reachable node (including the `from` nodes) in pre and post order. If
747 // loops are found, the `on_back_edge` functor will be called on the
748 // corresponding back edges. Moreover, the pre and post order will assume that
749 // these back edges will be cut.
750 struct DfsTypeCallbacks {
751 DfsTypeCallbacks() = default;
752 DfsTypeCallbacks(std::function<void(int)> pre, std::function<void(int)> post,
753 std::function<void(int, int)> back_edge)
754 : pre_order(std::move(pre)),
755 post_order(std::move(post)),
756 on_back_edge(std::move(back_edge)) {}
757
758 static DfsTypeCallbacks PreOrder(std::function<void(int)> pre) {
759 return DfsTypeCallbacks(std::move(pre), nullptr, nullptr);
760 }
761
762 static DfsTypeCallbacks PostOrder(std::function<void(int)> post) {
763 return DfsTypeCallbacks(nullptr, std::move(post), nullptr);
764 }
765
766 std::function<void(int)> pre_order;
767 std::function<void(int)> post_order;
768 std::function<void(int, int)> on_back_edge;
769 };
770
771 // Encapsulate DFS predicates for traversing the graph.
772 //
773 // The `enter` predicate decides if traversal should enter the node, and the
774 // `advance` predicate decides if the traversal should follow inputs/outputs
775 // from the node.
776 //
777 // If predicates are empty (default initialized), it's assumed that we can enter
778 // into any node and advance from any node respectively.
779 struct DfsTypePredicates {
780 DfsTypePredicates() = default;
781 DfsTypePredicates(std::function<bool(int)> enter,
782 std::function<bool(int)> advance)
783 : enter(std::move(enter)), advance(std::move(advance)) {}
784
785 static DfsTypePredicates Enter(std::function<bool(int)> enter) {
786 return DfsTypePredicates(std::move(enter), nullptr);
787 }
788
789 static DfsTypePredicates Advance(std::function<bool(int)> advance) {
790 return DfsTypePredicates(nullptr, std::move(advance));
791 }
792
793 std::function<bool(int)> enter;
794 std::function<bool(int)> advance;
795 };
796
797 struct DfsStackElem {
798 DfsStackElem(int node, bool children_visited, int src)
799 : node(node), children_visited(children_visited), src(src) {}
800 explicit DfsStackElem(int node) : DfsStackElem(node, false, -1) {}
801
802 // Index of the node in the graph ∊ [0, num_nodes).
803 int node;
804 // `True` if visited all the input/output nodes (pushed all input/output nodes
805 // to the stack).
806 bool children_visited;
807 // Index of the node in the graph, from which we entered the `node`.
808 int src;
809 };
810
811 enum class NodeState { kNotVisited, kVisiting, kDone };
812
813 void DfsTypeTraversal(const GraphTypeTopologyView& graph_type_view,
814 const absl::Span<const NodeTypeId* const> from,
815 const TypeTraversalDirection direction,
816 const DfsTypePredicates& predicates,
817 const DfsTypeCallbacks& callbacks) {
818 std::vector<DfsStackElem> stack;
819 stack.reserve(from.size());
820
821 for (const NodeTypeId* node : from) {
822 const absl::optional<int> node_idx = graph_type_view.GetNodeIndex(*node);
823 DCHECK(node_idx.has_value())
824 << "Illegal start node: " << node->node->name();
825 if (node_idx.has_value()) {
826 stack.emplace_back(node_idx.value());
827 }
828 }
829
830 absl::flat_hash_map<int, NodeState> node_state;
831 while (!stack.empty()) {
832 DfsStackElem w = stack.back();
833 stack.pop_back();
834
835 NodeState& state = node_state[w.node];
836 if (state == NodeState::kDone) continue;
837
838 // Skip nodes that we should not enter.
839 if (predicates.enter && !predicates.enter(w.node)) {
840 state = NodeState::kDone;
841 continue;
842 }
843
844 // We've processed all the children of this node.
845 if (w.children_visited) {
846 state = NodeState::kDone;
847 if (callbacks.post_order) {
848 callbacks.post_order(w.node);
849 }
850 continue;
851 }
852
853 // Loop detected.
854 if (state == NodeState::kVisiting) {
855 if (callbacks.on_back_edge) {
856 callbacks.on_back_edge(w.src, w.node);
857 }
858 continue;
859 }
860
861 state = NodeState::kVisiting;
862 if (callbacks.pre_order) {
863 callbacks.pre_order(w.node);
864 }
865
866 // Enqueue the node again with the children_visited flag set to true.
867 stack.emplace_back(w.node, true, w.src);
868
869 // Check if we can continue traversal from the current node.
870 if (predicates.advance && !predicates.advance(w.node)) {
871 continue;
872 }
873
874 // Now enqueue the fanin/fanout nodes.
875 if (direction == TypeTraversalDirection::kFollowInputs ||
876 direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
877 for (const int fanin : graph_type_view.GetFanin(w.node)) {
878 stack.emplace_back(fanin, false, w.node);
879 }
880 }
881 if (direction == TypeTraversalDirection::kFollowOutputs ||
882 direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
883 for (const int fanout : graph_type_view.GetFanout(w.node)) {
884 stack.emplace_back(fanout, false, w.node);
885 }
886 }
887 }
888 }
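// Usage sketch (illustrative only; patterned after the propagation passes
// further below): a post-order walk over type-level fanouts that only enters
// vertices not already painted deny. Here `roots` (a span of NodeTypeId
// pointers) and `deny_set` (an absl::flat_hash_set<int> of vertex indexes) are
// hypothetical names for this sketch.
//
//   DfsTypeTraversal(graph_type_view, roots,
//                    TypeTraversalDirection::kFollowOutputs,
//                    DfsTypePredicates::Enter([&](int idx) {
//                      return deny_set.count(idx) == 0;
//                    }),
//                    DfsTypeCallbacks::PostOrder([&](int idx) {
//                      // Visit graph_type_view.GetNode(idx) here.
//                    }));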
889
890 DataTypeSet AllowedDataTypes(const OpDef::AttrDef& attr_def) {
891 const auto& allowed_types = attr_def.allowed_values().list().type();
892 if (allowed_types.empty()) {
893 return AllTypes();
894 }
895 uint32 dtype_mask = 0;
896 for (int dtype : allowed_types) {
897 dtype_mask |= 1u << dtype;
898 }
899 return DataTypeSet(dtype_mask);
900 }
901
902 DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
903 if (t_attr_id.attr_name.empty()) {
904 return ToSet(t_attr_id.fixed_type);
905 }
906 const OpDef::AttrDef* attr_def = FindAttr(t_attr_id.attr_name, op_def);
907 CHECK(attr_def); // Crash Ok
908 return AllowedDataTypes(*attr_def);
909 }
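// Illustrative note (not part of the original source): the mask built above
// packs one bit per DataType enum value, so for an attr whose allowed list is
// {DT_FLOAT, DT_HALF}:
//   AllowedDataTypes(attr_def).Contains(DT_HALF)   -> true
//   AllowedDataTypes(attr_def).Contains(DT_DOUBLE) -> false
// An attr with no allowed_values restriction yields AllTypes().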
910
911 Status ValidateLists(const gtl::FlatSet<string>& allow_list,
912 const gtl::FlatSet<string>& deny_list,
913 const gtl::FlatSet<string>& infer_list,
914 const gtl::FlatSet<string>& clear_list) {
915 std::vector<gtl::FlatSet<string>> lists{allow_list, deny_list, infer_list,
916 clear_list};
917 std::multiset<string> counts;
918 for (const auto& list : lists) {
919 counts.insert(list.begin(), list.end());
920 }
921 bool duplicates = false;
922 for (const auto& s : counts) {
923 if (counts.count(s) > 1) {
924 duplicates = true;
925 LOG(ERROR) << "Op present in multiple lists: " << s;
926 }
927 }
928 if (duplicates) {
929 return errors::InvalidArgument("Op lists have conflicting entries");
930 } else {
931 return OkStatus();
932 }
933 }
934
935 bool HasInputOrOutputRefs(const NodeDef& node) {
936 const OpDef* op_def;
937 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
938 if (!status.ok()) {
939 return true;
940 }
941 for (const auto& input : op_def->input_arg()) {
942 if (input.is_ref()) {
943 return true;
944 }
945 }
946 for (const auto& output : op_def->output_arg()) {
947 if (output.is_ref()) {
948 return true;
949 }
950 }
951 return false;
952 }
953
954 // See TF issue 25977 for why SoftmaxCrossEntropyWithLogits (SCEWL) is excluded from forced FP16.
955 bool CanForceFP16(const NodeDef& node) {
956 return node.op() != "Const" && node.op() != "SoftmaxCrossEntropyWithLogits" &&
957 !IsStateful(node) && !HasInputOrOutputRefs(node);
958 }
959
960 int GetCudaVersion(
961 const std::unordered_map<string, DeviceProperties>& devices) {
962 for (const auto& device : devices) {
963 const DeviceProperties& device_properties = device.second;
964 if (device_properties.type() == "GPU") {
965 const auto& device_env = device_properties.environment();
966 auto it = device_env.find("cuda");
967 if (it != device_env.end()) {
968 string cuda_version_str = it->second;
969 return std::stoi(cuda_version_str);
970 }
971 }
972 }
973 return 0;
974 }
975
976 int GetCudnnVersion(
977 const std::unordered_map<string, DeviceProperties>& devices) {
978 for (const auto& device : devices) {
979 const DeviceProperties& device_properties = device.second;
980 if (device_properties.type() == "GPU") {
981 const auto& device_env = device_properties.environment();
982 auto it = device_env.find("cudnn");
983 if (it != device_env.end()) {
984 string cudnn_version_str = it->second;
985 return std::stoi(cudnn_version_str);
986 }
987 }
988 }
989 return 0;
990 }
991
992 std::unordered_map<string, DeviceProperties> GetDevices(Cluster* cluster) {
993 if (!ShouldSimulateGpu()) {
994 return cluster->GetDevices();
995 }
996
997 bool has_gpu = false;
998 for (const auto& device : cluster->GetDevices()) {
999 const DeviceProperties& device_properties = device.second;
1000 if (device_properties.type() == "GPU") {
1001 has_gpu = true;
1002 break;
1003 }
1004 }
1005
1006 if (has_gpu) {
1007 return cluster->GetDevices();
1008 }
1009
1010 std::unordered_map<string, DeviceProperties> devices(cluster->GetDevices());
1011 DeviceProperties gpu_device_properies;
1012 gpu_device_properies.set_type("GPU");
1013 #if GOOGLE_CUDA
1014 gpu_device_properies.set_vendor("NVIDIA");
1015 gpu_device_properies.mutable_environment()->insert({"architecture", "8.0"});
1016 gpu_device_properies.mutable_environment()->insert({"cuda", "11050"});
1017 gpu_device_properies.mutable_environment()->insert({"cudnn", "8302"});
1018 #elif TENSORFLOW_USE_ROCM
1019 gpu_device_properies.set_vendor("Advanced Micro Devices, Inc");
1020 gpu_device_properies.mutable_environment()->insert(
1021 {"architecture", "gfx908"});
1022 #endif
1023 devices.emplace(std::make_pair("/job:localhost/replica:0/task:0/device:GPU:0",
1024 gpu_device_properies));
1025 return devices;
1026 }
1027
1028 class AutoMixedPrecisionImpl {
1029 public:
1030 // CastType indicates the type of inserted Cast op
1031 // FP16: cast to float16
1032 // FP32: cast to float32
1033 // AUTO: cast to a data type that matches the required data type at fanouts
1034 enum class CastType { FP16, FP32, AUTO };
1035 AutoMixedPrecisionImpl(Cluster* cluster,
1036 const std::unordered_set<string>& nodes_to_preserve,
1037 GraphDef* graph, string id,
1038 AutoMixedPrecisionMode mode)
1039 : devices_(GetDevices(cluster)),
1040 virtual_placer_(devices_),
1041 nodes_to_preserve_(nodes_to_preserve),
1042 graph_(graph),
1043 function_library_(OpRegistry::Global(), graph->library()),
1044 id_(id),
1045 graph_view_(graph),
1046 cuda_version_(GetCudaVersion(devices_)),
1047 cudnn_version_(GetCudnnVersion(devices_)),
1048 num_nonvar_casts_to_f16_(0),
1049 mode_(mode),
1050 target_dtype_((mode_ == AutoMixedPrecisionMode::CUDA ||
1051 mode_ == AutoMixedPrecisionMode::CPU)
1052 ? DT_HALF
1053 : DT_BFLOAT16) {}
1054
1055 Status Optimize();
1056
1057 private:
1058 typedef absl::flat_hash_set<NodeTypeId> NodeTypeIdSet;
1059
1060 std::unique_ptr<AutoMixedPrecisionLists> get_mixed_precision_lists() const {
1061 switch (mode_) {
1062 case AutoMixedPrecisionMode::CUDA:
1063 return std::make_unique<AutoMixedPrecisionListsCuda>(cuda_version_,
1064 cudnn_version_);
1065 case AutoMixedPrecisionMode::BF16:
1066 return std::make_unique<AutoMixedPrecisionListsMkl>();
1067 case AutoMixedPrecisionMode::CPU:
1068 // Note: this is not a typo here. AutoMixedPrecisionListsCuda is used
1069 // intentionally to make CPU and GPU have the same fp16 ops.
1070 return std::make_unique<AutoMixedPrecisionListsCuda>(
1071 /*cuda_version=*/10000, // Hardcode cuda and cudnn version so
1072 /*cudnn_version=*/8000); // CPU emulates the same ops on GPU.
1073 }
1074 }
1075 Status PrintDebugLogs(bool preop, size_t timestamp);
1076 void LogSkippedNode(const NodeDef& node, const string& device_type) const;
1077 bool MustPreserve(const NodeDef& node) const;
1078 bool IsOnDevice(const NodeDef& node, const string& device_type) const;
1079 bool IsOnSuitableGPUArch(const NodeDef& node) const;
1080 bool ShouldProcess(const NodeDef& node) const;
1081 bool NodeHasF16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
1082 bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const;
1083 void ConvertBatchNormOpsToV2();
1084 bool SupportsF16(const NodeTypeId& node_type) const;
1085 bool SupportsF16DataType(const NodeTypeId& node_type) const;
1086 bool IsQuantized(const NodeTypeId& node_type) const;
1087 const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const;
1088 bool IsSourceOrSinkOp(const string& op) const;
1089 void FindFloat32TensorListOpClustersAndDenylistUnsafe(
1090 std::vector<absl::flat_hash_set<const NodeDef*>>* clusters,
1091 absl::flat_hash_set<int>* deny_set) const;
1092 void FindTensorListImplicitFloat32Edges(
1093 const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1094 std::vector<NodeTypeIdEdge>* implicit_fp32_edges) const;
1095 void AddAllowlistOps(absl::flat_hash_set<int>* allow_set) const;
1096 void RemoveAllowsetWithFp32(absl::flat_hash_set<int>* allow_set) const;
1097 void PropagateDenyFwdThroughClearAndInfer(
1098 absl::flat_hash_set<int>* deny_set) const;
1099 void ForceColorMatchBetweenTensorListOps(
1100 const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1101 absl::flat_hash_set<int>* allow_set,
1102 absl::flat_hash_set<int>* deny_set) const;
1103 void AddClearAndInferToAllowIfBetweenAllow(
1104 const absl::flat_hash_set<int>& deny_set,
1105 absl::flat_hash_set<int>* allow_set) const;
1106 void AddInferToAllowIfFollowAllow(const absl::flat_hash_set<int>& deny_set,
1107 absl::flat_hash_set<int>* allow_set) const;
1108 void PropagateAllowThroughClear(const absl::flat_hash_set<int>& deny_set,
1109 absl::flat_hash_set<int>* allow_set) const;
1110 Status ForceColorMatchOnRecurrentEdges(
1111 absl::flat_hash_set<int>* allow_set) const;
1112 void MakeCastsAllowIfAllOutputsAllow(
1113 absl::flat_hash_set<int>* allow_set) const;
1114 NodeDef BuildCastNode(const MutableGraphView::OutputPort& src,
1115 const MutableGraphView::InputPort& dst, bool to_f16,
1116 const string& device) const;
1117 StatusOr<NodeDef*> InsertCastNodeAtFanout(
1118 const absl::flat_hash_set<int>& allow_set, const bool src_is_allow,
1119 const CastType& cast_type, MutableGraphView::OutputPort& src);
1120
1121 StatusOr<DataType> GetCastToType(const NodeDef* node) const;
1122 void CollectOutputPorts(
1123 const TypeAttrId& type_attr, NodeDef* node,
1124 std::vector<MutableGraphView::OutputPort>& output_ports) const;
1125 Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& allow_set);
1126
1127 std::unordered_map<string, DeviceProperties> devices_;
1128 VirtualPlacer virtual_placer_;
1129 std::unordered_set<string> nodes_to_preserve_;
1130 GraphDef* graph_;
1131 FunctionLibraryDefinition function_library_;
1132 string id_;
1133 MutableGraphView graph_view_;
1134 int cuda_version_;
1135 int cudnn_version_;
1136 int num_nonvar_casts_to_f16_;
1137 NodeTypeAttrMap node_type_map_;
1138 GraphTypeTopologyView graph_type_view_;
1139 bool force_all_fp16_;
1140 bool treat_infer_as_deny_;
1141 AutoMixedPrecisionMode mode_;
1142 gtl::FlatSet<string> f16_allowlist_;
1143 gtl::FlatSet<string> f16_denylist_;
1144 gtl::FlatSet<string> f16_inferlist_;
1145 gtl::FlatSet<string> f16_clearlist_;
1146 absl::flat_hash_set<const NodeDef*> should_process_nodes_;
1147 DataType target_dtype_; // Either DT_HALF or DT_BFLOAT16
1148 };
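// Usage sketch (illustrative): the grappler pass constructs one instance per
// graph and optimizes it in place, roughly:
//
//   AutoMixedPrecisionImpl impl(cluster, nodes_to_preserve, &graph_def,
//                               /*id=*/"graph_id", AutoMixedPrecisionMode::CUDA);
//   TF_RETURN_IF_ERROR(impl.Optimize());
//
// where `nodes_to_preserve` holds the names of nodes (e.g. fetch nodes) whose
// types must not be changed.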
1149
1150 NodeDef AutoMixedPrecisionImpl::BuildCastNode(
1151 const MutableGraphView::OutputPort& src,
1152 const MutableGraphView::InputPort& dst, bool to_f16,
1153 const string& device) const {
1154 DataType src_type = to_f16 ? DT_FLOAT : target_dtype_;
1155 DataType dst_type = to_f16 ? target_dtype_ : DT_FLOAT;
1156 const char* cast_string = !to_f16 ? kCastToFp32
1157 : target_dtype_ == DT_HALF ? kCastToFp16
1158 : kCastToBf16;
1159 string name =
1160 strings::StrCat(src.node->name(), "-", src.port_id, "-", dst.node->name(),
1161 "-", dst.port_id, "-", cast_string, "-", kSuffix);
1162 NodeDef node;
1163 node.set_name(name);
1164 node.set_op("Cast");
1165 node.set_device(device);
1166 node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
1167 (*node.mutable_attr())["SrcT"].set_type(src_type);
1168 (*node.mutable_attr())["DstT"].set_type(dst_type);
1169 (*node.mutable_attr())["Truncate"].set_b(false);
1170 return node;
1171 }
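// Illustrative example (hypothetical node names): with target_dtype_ == DT_HALF
// and to_f16 == true, a cast inserted between output MatMul:0 and input Relu:0
// is named "MatMul-0-Relu-0-CastToFp16-AutoMixedPrecision" and carries
// SrcT = DT_FLOAT, DstT = DT_HALF, Truncate = false.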
1172
1173 bool AutoMixedPrecisionImpl::NodeHasF16KernelForTypeAttr(
1174 const NodeDef& node, TypeAttrId taid) const {
1175 NodeDef node_copy(node);
1176 if (node.device().empty()) {
1177 string device_name = virtual_placer_.get_canonical_device_name(node);
1178 node_copy.set_device(device_name);
1179 }
1180 if (!SetDataType(&node_copy, taid, target_dtype_)) {
1181 return false;
1182 }
1183 return IsKernelRegisteredForNode(node_copy).ok();
1184 }
1185
1186 Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
1187 string prepend_path;
1188 TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
1189 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LOG_PATH", "", &prepend_path));
1190 if (prepend_path.empty()) return OkStatus();
1191
1192 string suffix =
1193 strings::StrCat("_", preop ? "preop" : kSuffix, "_", id_, "_", timestamp);
1194
1195 string fname =
1196 io::JoinPath(prepend_path, strings::StrCat("graphdef", suffix, ".pb"));
1197 std::fstream f;
1198 f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
1199 f << graph_->SerializeAsString();
1200 f.close();
1201 LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
1202 << " graph as binary to " << fname;
1203
1204 fname = io::JoinPath(prepend_path,
1205 strings::StrCat("graphdef", suffix, ".pb.txt"));
1206 f.open(fname.c_str(), std::fstream::out);
1207 f << graph_->DebugString();
1208 f.close();
1209 LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
1210 << " graph as text to " << fname;
1211
1212 if (!preop) {
1213 fname = io::JoinPath(prepend_path,
1214 strings::StrCat("paintbuckets", suffix, ".txt"));
1215 f.open(fname.c_str(), std::fstream::out);
1216 std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
1217 get_mixed_precision_lists();
1218 f << "AllowList:\n";
1219 for (const auto& x : mp_lists->AllowList()) {
1220 f << x << "\n";
1221 }
1222 f << "\nDenyList:\n";
1223 for (const auto& x : mp_lists->DenyList()) {
1224 f << x << "\n";
1225 }
1226 f << "\nInferList:\n";
1227 for (const auto& x : mp_lists->InferList()) {
1228 f << x << "\n";
1229 }
1230 f << "\nClearList:\n";
1231 for (const auto& x : mp_lists->ClearList()) {
1232 f << x << "\n";
1233 }
1234 f.close();
1235 LOG(INFO) << "Saved paint bucket info to " << fname;
1236 }
1237 return OkStatus();
1238 }
1239
1240 void AutoMixedPrecisionImpl::LogSkippedNode(const NodeDef& node,
1241 const string& device_type) const {
1242 VLOG(2) << "Skipping " << node.op() << " node " << node.name()
1243 << " because it "
1244 << (MustPreserve(node)
1245 ? "must be preserved"
1246 : absl::StrFormat(
1247 "is not on the %s, or the %s arch is not suitable",
1248 device_type, device_type));
1249 }
1250
1251 bool AutoMixedPrecisionImpl::MustPreserve(const NodeDef& node) const {
1252 return nodes_to_preserve_.count(node.name());
1253 }
1254
1255 bool AutoMixedPrecisionImpl::IsOnDevice(const NodeDef& node,
1256 const string& device_type) const {
1257 string device_name;
1258 if (node.device().empty()) {
1259 device_name = virtual_placer_.get_canonical_device_name(node);
1260 } else {
1261 device_name = node.device();
1262 }
1263 string device;
1264 string not_used;
1265 if (DeviceNameUtils::SplitDeviceName(device_name, ¬_used, &device) &&
1266 absl::StrContains(absl::AsciiStrToLower(device),
1267 absl::AsciiStrToLower(device_type))) {
1268 return true;
1269 }
1270 return false;
1271 }
1272
1273 bool AutoMixedPrecisionImpl::IsOnSuitableGPUArch(const NodeDef& node) const {
1274 return HasFastFP16Support(virtual_placer_.get_device(node));
1275 }
1276
1277 bool AutoMixedPrecisionImpl::ShouldProcess(const NodeDef& node) const {
1278 return should_process_nodes_.count(&node);
1279 }
1280
1281 bool IsFloat32(const NodeTypeId& node_type) {
1282 return GetDataType(*node_type.node, node_type.type_attr) ==
1283 DataType::DT_FLOAT;
1284 }
1285
1286 bool IsTensorListOp(const string& op) {
1287 return absl::StrContains(op, "TensorList");
1288 }
1289
1290 bool IsTensorListReaderOp(const string& op) {
1291 static const gtl::FlatSet<string> tensor_list_reader_ops = {
1292 "TensorListConcat", "TensorListConcatV2", "TensorListGather",
1293 "TensorListGetItem", "TensorListPopBack", "TensorListStack"};
1294 return tensor_list_reader_ops.count(op);
1295 }
1296
1297 bool IsTensorListWriterOp(const string& op) {
1298 static const gtl::FlatSet<string> tensor_list_writer_ops = {
1299 "TensorListFromTensor", "TensorListPushBack",
1300 "TensorListPushBackBatch", "TensorListScatter",
1301 "TensorListScatterV2", "TensorListScatterIntoExistingList",
1302 "TensorListSetItem", "TensorListSplit"};
1303 return tensor_list_writer_ops.count(op);
1304 }
1305
1306 bool AutoMixedPrecisionImpl::SupportsF16(const NodeTypeId& node_type) const {
1307 const OpDef* op_def;
1308 Status status =
1309 OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
1310 if (!status.ok()) return false;
1311 return AllowedDataTypes(*op_def, node_type.type_attr)
1312 .Contains(target_dtype_) &&
1313 NodeHasF16KernelForTypeAttr(*node_type.node, node_type.type_attr);
1314 }
1315
1316 bool AutoMixedPrecisionImpl::SupportsF16DataType(
1317 const NodeTypeId& node_type) const {
1318 const OpDef* op_def;
1319 Status status =
1320 OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
1321 if (!status.ok()) return false;
1322 return AllowedDataTypes(*op_def, node_type.type_attr).Contains(target_dtype_);
1323 }
1324
1325 bool AutoMixedPrecisionImpl::IsQuantized(const NodeTypeId& node_type) const {
1326 for (const TypeAttrId& type_attr :
1327 node_type_map_.GetTypeAttrs(*node_type.node)) {
1328 if (DataTypeIsQuantized(GetDataType(*node_type.node, type_attr))) {
1329 return true;
1330 }
1331 }
1332 return false;
1333 }
1334
1335 // TODO(mconley): Make this change the node's name (to aid debugging). Need to
1336 // make sure that doing this won't break anything.
1337 void AutoMixedPrecisionImpl::ConvertBatchNormOpsToV2() {
1338 for (int node_idx = 0; node_idx < graph_->node_size(); ++node_idx) {
1339 NodeDef* node = graph_->mutable_node(node_idx);
1340 if (!ShouldProcess(*node)) continue;
1341 bool changed = false;
1342 if (node->op() == "FusedBatchNorm") {
1343 VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
1344 << " to FusedBatchNormV2";
1345 node->set_op("FusedBatchNormV2");
1346 changed = true;
1347 } else if (node->op() == "FusedBatchNormGrad") {
1348 VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
1349 << " to FusedBatchNormGradV2";
1350 node->set_op("FusedBatchNormGradV2");
1351 changed = true;
1352 }
1353 if (changed) {
1354 (*node->mutable_attr())["U"].set_type(DT_FLOAT);
1355 }
1356 }
1357 }
1358
1359 // A helper function to decide whether to ignore the effect on performance when
1360 // rewriting the graph. This can be useful for testing the numerical effects of
1361 // reduced precision on systems that have poor mixed precision performance.
1362 bool ShouldIgnorePerformance() {
1363 static bool is_enabled = [] {
1364 bool ret = false;
1365 TF_CHECK_OK(ReadBoolFromEnvVar(
1366 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE",
1367 /*default_val=*/false, &ret));
1368 return ret;
1369 }();
1370 return is_enabled;
1371 }
1372
1373 Status AutoMixedPrecisionImpl::Optimize() {
1374 string optimization_level;
1375 TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
1376 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level));
1377 optimization_level = absl::AsciiStrToUpper(optimization_level);
1378 force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL";
1379 if (force_all_fp16_ && mode_ == AutoMixedPrecisionMode::BF16) {
1380 // Many ops do not support bfloat16 on the CPU, so we disallow forcing
1381 // everything to bfloat16.
1382 return errors::InvalidArgument(
1383 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL cannot be set to "
1384 "UNSAFE_FORCE_ALL when oneDNN is used");
1385 }
1386
1387 treat_infer_as_deny_ = optimization_level == "TREAT_INFER_AS_DENY";
1388 VLOG(2) << "Optimization Level: " << optimization_level;
1389
1390 std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
1391 get_mixed_precision_lists();
1392 f16_allowlist_ = mp_lists->AllowList();
1393 f16_denylist_ = mp_lists->DenyList();
1394
1395 if (treat_infer_as_deny_) {
1396 for (const auto& op : mp_lists->InferList()) {
1397 f16_denylist_.insert(op);
1398 }
1399 } else {
1400 f16_inferlist_ = mp_lists->InferList();
1401 }
1402
1403 f16_clearlist_ = mp_lists->ClearList();
1404 TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_denylist_,
1405 f16_inferlist_, f16_clearlist_));
1406
1407 size_t timestamp = Env::Default()->NowMicros() / 1000;
1408 TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));
1409
1410 VLOG(2) << "Identifying nodes that should be processed";
1411 for (const NodeDef& node : graph_->node()) {
1412 bool should_process;
1413 string device_type;
1414 switch (mode_) {
1415 case AutoMixedPrecisionMode::CUDA:
1416 device_type = DEVICE_GPU;
1417 should_process =
1418 !MustPreserve(node) && IsOnDevice(node, device_type) &&
1419 (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node));
1420 break;
1421 case AutoMixedPrecisionMode::BF16:
1422 case AutoMixedPrecisionMode::CPU:
1423 device_type = DEVICE_CPU;
1424 should_process = !MustPreserve(node) && IsOnDevice(node, device_type);
1425 break;
1426 }
1427 if (should_process) {
1428 should_process_nodes_.insert(&node);
1429 } else {
1430 LogSkippedNode(node, device_type);
1431 }
1432 }
1433
1434 VLOG(2) << "Converting FusedBatchNorm* ops to V2";
1435 ConvertBatchNormOpsToV2();
1436
1437 VLOG(2) << "Building node type map for graph";
1438 TF_RETURN_IF_ERROR(node_type_map_.Init(*graph_));
1439
1440 VLOG(2) << "Constructing graph type attribute topology view";
1441 TF_RETURN_IF_ERROR(
1442 graph_type_view_.InitializeFromGraph(*graph_, node_type_map_));
1443
1444 absl::flat_hash_set<int> deny_set;
1445
1446 std::vector<absl::flat_hash_set<const NodeDef*>> tensor_list_clusters;
1447 FindFloat32TensorListOpClustersAndDenylistUnsafe(&tensor_list_clusters,
1448 &deny_set);
1449 std::vector<NodeTypeIdEdge> ephemeral_edges;
1450 for (const auto& cluster : tensor_list_clusters) {
1451 VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size();
1452 for (const NodeDef* node : cluster) {
1453 VLOG(2) << " Cluster member: " << node->op() << " node " << node->name();
1454 }
1455 FindTensorListImplicitFloat32Edges(cluster, &ephemeral_edges);
1456 }
1457 TF_RETURN_IF_ERROR(graph_type_view_.AddEphemeralEdges(ephemeral_edges));
1458
1459 // The goal here is to change performance-critical ops to fp16 or bf16, and to
1460 // do so with the minimal number of casts, subject to the constraint that the
1461 // model's convergence is not affected. This is achieved by first identifying
1462 // which nodes should be changed to f16 and then inserting casts at the
1463 // boundaries between f16/non-f16 nodes.
1464
1465 // The algorithm for deciding which nodes to change to f16 is as follows:
1466 // 1) Add all performance-critical ops (aka "allowlist" ops) to the allow_set.
1467 // This is done under the assumption that allowlist ops are always
1468 // numerically-safe in f16 and that they are the most important ops for
1469 // improving performance.
1470 // 2) Add nodes to the deny_set iff they are numerically-dangerous (aka
1471 // "denylist" ops) or they are on a forward path from a denylist node to
1472 // a deny/infer node (including the node at the end of the path) through
1473 // non-numerically-dangerous ops (aka "inferlist" and "clearlist" ops).
1474 // This is done to prevent numerically-dangerous ops and their downstream
1475 // effects from being changed to f16, which would risk breaking the
1476 // numerical accuracy of the model.
1477 // 3) For all remaining nodes that are not considered dangerous (inferlist
1478 // and clearlist ops), find those that are between (i.e., both upstream
1479 // and downstream of) allow nodes, and add them to the allow_set.
1480 // This is done to avoid unnecessary casts between allowlist ops.
1481 //  4) For the remaining inferlist nodes, add them to the allow_set if they
1482 //     are immediately downstream of an allow_set node.
1483 // 5) For all remaining clearlist nodes, add them to the allow_set if they are
1484 // connected to a node in the allow_set via other clearlist nodes.
1485 // This is done to increase the number of ops in the allow_set without
1486 // affecting numerical stability.
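// As a concrete illustration (a hypothetical chain; it assumes Conv2D and
// MatMul are on the allowlist, Relu on the clearlist, and Exp on the
// inferlist):
//   Conv2D -> Relu -> Exp -> MatMul
// Pass 1 paints Conv2D and MatMul ALLOW; pass 3 then paints Relu and Exp ALLOW
// because they lie between two ALLOW nodes, so no intermediate Casts are
// needed on this chain.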
1487
1488 absl::flat_hash_set<int> allow_set;
1489 VLOG(2) << "Beginning pass 1 to add allowlist ops";
1490 AddAllowlistOps(&allow_set);
1491 VLOG(2) << "Finished pass 1";
1492
1493 if (allow_set.empty()) {
1494 LOG(INFO) << "No allowlist ops found, nothing to do";
1495 return OkStatus();
1496 }
1497
1498 VLOG(2) << "Beginning pass 2 to propagate deny forwards from denylist ops "
1499 "through clear/inferlist ops";
1500 PropagateDenyFwdThroughClearAndInfer(&deny_set);
1501 VLOG(2) << "Finished pass 2";
1502
1503 VLOG(2) << "Forcing color match between data structure ops";
1504 for (const auto& cluster : tensor_list_clusters) {
1505 ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
1506 }
1507
1508 VLOG(2) << "Beginning pass 3 to set clear and infer nodes to allow if they "
1509 "are between allow ops";
1510 AddClearAndInferToAllowIfBetweenAllow(deny_set, &allow_set);
1511 VLOG(2) << "Finished pass 3";
1512
1513 VLOG(2) << "Beginning pass 4 to add infer list ops to allow if they "
1514 "directly follow allow nodes";
1515 AddInferToAllowIfFollowAllow(deny_set, &allow_set);
1516 VLOG(2) << "Finished pass 4";
1517
1518 VLOG(2) << "Beginning pass 5 to propagate allow from allow nodes through "
1519 "clearlist ops";
1520 PropagateAllowThroughClear(deny_set, &allow_set);
1521 VLOG(2) << "Finished pass 5";
1522
1523 VLOG(2) << "Beginning pass 6 to remove nodes that cannot be changed "
1524 "to F16 "
1525 "from the allow set";
1526 RemoveAllowsetWithFp32(&allow_set);
1527 VLOG(2) << "Finished pass 6";
1528
1529 VLOG(2) << "Forcing color match between data structure ops";
1530 for (const auto& cluster : tensor_list_clusters) {
1531 ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
1532 }
1533
1534 VLOG(2) << "Forcing color match on loop edges";
1535 TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&allow_set));
1536
1537 VLOG(2) << "Finding existing casts that can be made allow";
1538 MakeCastsAllowIfAllOutputsAllow(&allow_set);
1539
1540 VLOG(2) << "Beginning final pass to change type attributes and insert Cast "
1541 "ops at paint boundaries";
1542 TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(allow_set));
1543 VLOG(2) << "Finished final pass";
1544
1545 TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp));
1546
1547 return OkStatus();
1548 }
1549
1550 // If node is a Tensor List op with a float32 data type attribute then this
1551 // returns a pointer to the NodeTypeId representing that type attribute. In
1552 // all other cases this returns nullptr.
1553 const NodeTypeId* AutoMixedPrecisionImpl::GetTensorListFloat32NodeTypeId(
1554 const NodeDef& node) const {
1555 if (!IsTensorListOp(node.op())) return nullptr;
1556 for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(node)) {
1557 const NodeTypeId* node_type =
1558 graph_type_view_.GetNode(node.name(), type_attr);
1559 // This assumes that the float32 data type on a Tensor List op is always a
1560 // non-fixed type attribute containing a single type, and that this type
1561 // attribute represents the dtype of the values in the list.
1562 // TODO(benbarsdell): A new Tensor List op could theoretically break these
1563 // assumptions.
1564 if (node_type && node_type->type_attr.fixed_type == DT_INVALID &&
1565 node_type->type_attr.type_index == TypeAttrId::kSingleType &&
1566 IsFloat32(*node_type)) {
1567 return node_type;
1568 }
1569 }
1570 return nullptr;
1571 }
1572
1573 bool AutoMixedPrecisionImpl::IsSourceOrSinkOp(const string& op) const {
1574 const gtl::FlatSet<string> source_and_sink_ops = {
1575 "_Arg",
1576 "_Retval",
1577 "OptionalFromValue",
1578 "OptionalGetValue",
1579 "PartitionedCall",
1580 "Placeholder",
1581 "StatefulPartitionedCall",
1582 };
1583 return source_and_sink_ops.count(op) || function_library_.Find(op);
1584 }
1585
1586 // Finds all clusters of float32 Tensor List nodes that are connected via their
1587 // handle edges. Unsafe clusters (those with unprocessable nodes, or with edges
1588 // that cross untraversable boundaries via _Arg, _Ret, PartitionedCall etc.
1589 // nodes) are added to deny_set. The caller should paint all nodes in a cluster
1590 // the same color, as they may all refer to the same Tensor List.
1591 void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndDenylistUnsafe(
1592 std::vector<absl::flat_hash_set<const NodeDef*>>* tensor_list_clusters,
1593 absl::flat_hash_set<int>* deny_set) const {
1594 absl::flat_hash_set<const NodeDef*> tensor_list_prop_set;
1595 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1596 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1597 if (!ShouldProcess(*root.node) ||
1598 root.type_attr.fixed_type != DataType::DT_VARIANT ||
1599 !GetTensorListFloat32NodeTypeId(*root.node) ||
1600 tensor_list_prop_set.count(root.node)) {
1601 continue;
1602 }
1603 const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node);
1604 const absl::optional<int> maybe_root_fp32_idx =
1605 graph_type_view_.GetNodeIndex(*root_fp32);
1606 DCHECK(maybe_root_fp32_idx.has_value())
1607 << "Type attribute " << root_fp32->type_attr.DebugString()
1608 << " of node " << root.node->name() << " not found in graph view";
1609 int root_fp32_idx = maybe_root_fp32_idx.value();
1610 // Traverse Tensor List handle edges (DT_VARIANT) to find cluster of all
1611 // connected Tensor List nodes.
1612 absl::flat_hash_set<const NodeDef*> cluster({root.node});
1613 DfsTypeTraversal(graph_type_view_, {&root},
1614 TypeTraversalDirection::kFollowInputsAndOutputs,
1615 DfsTypePredicates::Enter([&](int idx) -> bool {
1616 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1617 return !tensor_list_prop_set.count(item.node);
1618 }),
1619 DfsTypeCallbacks::PreOrder([&](int idx) {
1620 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1621 const NodeDef* node = item.node;
1622 if (GetTensorListFloat32NodeTypeId(*node)) {
1623 cluster.insert(node);
1624 if (!ShouldProcess(*node)) {
1625 // The cluster contains an un-processable node.
1626 deny_set->insert(root_fp32_idx);
1627 }
1628 // TODO(benbarsdell): In a theoretical pathological
1629 // case of a Tensor List of Tensor List handles, the
1630 // Tensor List itself would need to be treated as a
1631 // sink.
1632 } else if (IsSourceOrSinkOp(node->op())) {
1633 // The cluster crosses an untraversable boundary.
1634 deny_set->insert(root_fp32_idx);
1635 }
1636 }));
1637 tensor_list_clusters->push_back(cluster);
1638 }
1639 }
1640
1641 // Finds all writer -> reader pairs in the given set that are connected via
1642 // their handles, and adds corresponding float32 edges to *implicit_fp32_edges.
1643 void AutoMixedPrecisionImpl::FindTensorListImplicitFloat32Edges(
1644 const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1645 std::vector<NodeTypeIdEdge>* implicit_fp32_edges) const {
1646 for (const NodeDef* root_node : tensor_list_nodes) {
1647 if (!IsTensorListReaderOp(root_node->op())) continue;
1648 NodeTypeId root(root_node, TypeAttrId(DataType::DT_VARIANT));
1649 const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node);
1650 CHECK(root_fp32) << "No float32 type attribute found for " // Crash OK
1651 << root.node->op() << " node " << root.node->name();
1652 // Search backwards through handle edges (DT_VARIANT) for all writer ops,
1653 // adding direct implicit edges between them and the reader.
1654 DfsTypeTraversal(
1655 graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
1656 DfsTypePredicates::Enter([&](int idx) -> bool {
1657 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1658 return ShouldProcess(*item.node);
1659 }),
1660 DfsTypeCallbacks::PreOrder([&](int idx) {
1661 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1662 if (IsTensorListWriterOp(item.node->op())) {
1663 const NodeTypeId* item_fp32 =
1664 GetTensorListFloat32NodeTypeId(*item.node);
1665 CHECK(item_fp32) // Crash OK
1666 << "No float32 type attribute found for " << item.node->op()
1667 << " node " << item.node->name();
1668 VLOG(2) << "Adding ephemeral float32 edge from "
1669 << item_fp32->node->op() << " node "
1670 << item_fp32->node->name() << " to "
1671 << root_fp32->node->op() << " node "
1672 << root_fp32->node->name();
1673 implicit_fp32_edges->emplace_back(*item_fp32, *root_fp32);
1674 }
1675 }));
1676 }
1677 }
1678
1679 void AutoMixedPrecisionImpl::AddAllowlistOps(
1680 absl::flat_hash_set<int>* allow_set) const {
1681 // Add allowlisted ops to allow_set.
1682 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1683 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1684 if (!ShouldProcess(*root.node)) continue;
1685 bool force_allow = force_all_fp16_ && CanForceFP16(*root.node);
1686 if (f16_allowlist_.count(root.node->op()) || force_allow) {
1687 bool inserted = allow_set->insert(root_idx).second;
1688 if (VLOG_IS_ON(2) && inserted) {
1689 VLOG(2) << "Painting type " << root.type_attr.DebugString()
1690 << " of node " << root.node->name() << " ALLOW because its op "
1691 << root.node->op() << " is on the allowlist";
1692 }
1693 }
1694 }
1695 }
1696
1697 // Adds nodes to deny_set iff they are on the denylist or they are on a
1698 // forward path from a denylist node to a deny/infer node (including the node
1699 // at the end of the path) through clear and infer nodes.
1700 // E.g., deny -> infer -> clear -> infer -> clear -> allow -> infer
1701 // becomes: deny -> deny -> deny -> deny -> clear -> allow -> infer.
1702 void AutoMixedPrecisionImpl::PropagateDenyFwdThroughClearAndInfer(
1703 absl::flat_hash_set<int>* deny_set) const {
1704 if (force_all_fp16_) return;
1705
1706 // Find clear nodes that are upstream of deny or infer.
1707 absl::flat_hash_set<int> upstream_of_deny_or_infer_set;
1708 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1709 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1710 if (!(f16_denylist_.count(root.node->op()) ||
1711 f16_inferlist_.count(root.node->op()))) {
1712 continue;
1713 }
1714 DfsTypeTraversal(graph_type_view_, {&root},
1715 TypeTraversalDirection::kFollowInputs,
1716 DfsTypePredicates::Enter([&](int idx) -> bool {
1717 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1718 return idx == root_idx ||
1719 (!upstream_of_deny_or_infer_set.count(idx) &&
1720 f16_clearlist_.count(item.node->op()));
1721 }),
1722 DfsTypeCallbacks::PreOrder([&](int idx) {
1723 upstream_of_deny_or_infer_set.insert(idx);
1724 }));
1725 }
1726
1727 // Propagate deny forward through nodes in upstream_of_deny_or_infer_set.
1728 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1729 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1730 if (deny_set->count(root_idx) || !f16_denylist_.count(root.node->op())) {
1731 continue;
1732 }
1733 DfsTypeTraversal(
1734 graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
1735 DfsTypePredicates::Enter([&](int idx) -> bool {
1736 return idx == root_idx || (!deny_set->count(idx) &&
1737 upstream_of_deny_or_infer_set.count(idx));
1738 }),
1739 DfsTypeCallbacks::PreOrder([&](int idx) {
1740 bool inserted = deny_set->insert(idx).second;
1741 if (VLOG_IS_ON(2) && inserted) {
1742 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1743 VLOG(2) << "Painting type " << item.type_attr.DebugString()
1744 << " of " << item.node->op() << " node "
1745 << item.node->name() << " DENY";
1746 }
1747 }));
1748 }
1749 }
1750
1751 void AutoMixedPrecisionImpl::AddClearAndInferToAllowIfBetweenAllow(
1752 const absl::flat_hash_set<int>& deny_set,
1753 absl::flat_hash_set<int>* allow_set) const {
1754 // Find clear/inferlist ops that are downstream of allow ops.
1755 absl::flat_hash_set<int> downstream_of_allow_set;
1756 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1757 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1758 if (!ShouldProcess(*root.node) || !f16_allowlist_.count(root.node->op())) {
1759 continue;
1760 }
1761 DfsTypeTraversal(
1762 graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
1763 DfsTypePredicates::Enter([&](int idx) -> bool {
1764 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1765 return idx == root_idx ||
1766 (!downstream_of_allow_set.count(idx) &&
1767 !f16_allowlist_.count(item.node->op()) &&
1768 !deny_set.count(idx) && ShouldProcess(*item.node) &&
1769 // TODO(benbarsdell): Consider allowing propagation through
1770 // ops that are already float16 in order to reduce the number
1771 // of casts.
1772 IsFloat32(item) && SupportsF16(item) &&
1773 (f16_clearlist_.count(item.node->op()) ||
1774 f16_inferlist_.count(item.node->op())));
1775 }),
1776 DfsTypeCallbacks::PreOrder(
1777 [&](int idx) { downstream_of_allow_set.insert(idx); }));
1778 }
1779
1780 // Set nodes that are both downstream and upstream of allow ops to allow.
1781 absl::flat_hash_set<int> upstream_of_allow_set;
1782 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1783 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1784 if (!ShouldProcess(*root.node) || upstream_of_allow_set.count(root_idx) ||
1785 !f16_allowlist_.count(root.node->op())) {
1786 continue;
1787 }
1788 DfsTypeTraversal(
1789 graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
1790 DfsTypePredicates::Enter([&](int idx) -> bool {
1791 return idx == root_idx || (!upstream_of_allow_set.count(idx) &&
1792 downstream_of_allow_set.count(idx));
1793 }),
1794 DfsTypeCallbacks::PreOrder([&](int idx) {
1795 upstream_of_allow_set.insert(idx);
1796 bool inserted = allow_set->insert(idx).second;
1797 if (VLOG_IS_ON(2) && inserted) {
1798 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1799 VLOG(2) << "Painting type " << item.type_attr.DebugString()
1800 << " of " << item.node->op() << " node "
1801 << item.node->name() << " ALLOW";
1802 }
1803 }));
1804 }
1805 }
1806
1807 void AutoMixedPrecisionImpl::PropagateAllowThroughClear(
1808 const absl::flat_hash_set<int>& deny_set,
1809 absl::flat_hash_set<int>* allow_set) const {
1810 // Propagate allow from allow nodes through clearlist ops.
1811 absl::flat_hash_set<int> clear_prop_set;
1812 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1813 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1814 if (!ShouldProcess(*root.node) || clear_prop_set.count(root_idx) ||
1815 !allow_set->count(root_idx)) {
1816 continue;
1817 }
1818 DfsTypeTraversal(
1819 graph_type_view_, {&root},
1820 TypeTraversalDirection::kFollowInputsAndOutputs,
1821 DfsTypePredicates::Enter([&](int idx) -> bool {
1822 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1823 return idx == root_idx ||
1824 (!allow_set->count(idx) && !deny_set.count(idx) &&
1825 ShouldProcess(*item.node) && IsFloat32(item) &&
1826 SupportsF16(item) &&
1827 (f16_clearlist_.count(item.node->op())) &&
1828 // We don't propagate (backwards) through nodes that read
1829 // Variables because it can break the behavior of TensorBoard
1830 // visualization and/or (in the case of Enter nodes) the model
1831 // itself. This is only a problem for non-resource variables.
1832 !NodeImplicitlyReadsNonResourceVariable(*item.node));
1833 }),
1834 DfsTypeCallbacks::PreOrder([&](int idx) {
1835 clear_prop_set.insert(idx);
1836 bool inserted = allow_set->insert(idx).second;
1837 if (VLOG_IS_ON(2) && inserted) {
1838 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1839 VLOG(2) << "Painting type " << item.type_attr.DebugString()
1840 << " of " << item.node->op() << " node "
1841 << item.node->name() << " ALLOW";
1842 }
1843 }));
1844 }
1845 }
1846
1847 // Sets an infer node to allow if a direct fanin is in the allow set and none is in the deny set.
1848 void AutoMixedPrecisionImpl::AddInferToAllowIfFollowAllow(
1849 const absl::flat_hash_set<int>& deny_set,
1850 absl::flat_hash_set<int>* allow_set) const {
1851 // Currently this only targets the oneDNN (BF16) path.
1852 if (mode_ != AutoMixedPrecisionMode::BF16) {
1853 return;
1854 }
1855 for (int item_idx = 0; item_idx < graph_type_view_.num_nodes(); ++item_idx) {
1856 const NodeTypeId& item = *graph_type_view_.GetNode(item_idx);
1857 if (!ShouldProcess(*item.node) || deny_set.count(item_idx) ||
1858 allow_set->count(item_idx) || !f16_inferlist_.count(item.node->op()) ||
1859 !IsFloat32(item) || !SupportsF16DataType(item)) {
1860 continue;
1861 }
1862
1863 bool has_allow_fanin = false;
1864 for (const int fanin : graph_type_view_.GetFanin(item_idx)) {
1865 if (deny_set.count(fanin)) {
1866 has_allow_fanin = false;
1867 break;
1868 }
1869 if (allow_set->count(fanin)) {
1870 has_allow_fanin = true;
1871 }
1872 }
1873 if (has_allow_fanin) {
1874 bool inserted = allow_set->insert(item_idx).second;
1875 if (VLOG_IS_ON(2) && inserted) {
1876 VLOG(2) << "Painting type " << item.type_attr.DebugString() << " of "
1877 << item.node->op() << " node " << item.node->name() << " ALLOW";
1878 }
1879 }
1880 }
1881 }
1882
1883 // If an op has a type attribute that cannot be converted to F16, remove it
1884 // from the allow_set. For example, FusedBatchNormV2/FusedBatchNormV3 have a
1885 // type attribute 'U' that only supports float.
1886 // Quantized ops are also never converted to FP16.
1887 void AutoMixedPrecisionImpl::RemoveAllowsetWithFp32(
1888 absl::flat_hash_set<int>* allow_set) const {
1889 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1890 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1891 if (f16_allowlist_.count(root.node->op()) && allow_set->count(root_idx) &&
1892 (!SupportsF16DataType(root) || IsQuantized(root))) {
1893 auto erased = allow_set->erase(root_idx);
1894 if (VLOG_IS_ON(2) && erased) {
1895 VLOG(2) << "UnPainting type " << root.type_attr.DebugString()
1896 << " of node " << root.node->name() << " ALLOW because its op "
1897 << root.node->op() << " is not support F16 DataType";
1898 }
1899 }
1900 }
1901 }
1902
1903 // Forces NextIteration nodes and their output Merge node(s) to have the same
1904 // color. Specifically, it removes them all from allow_set if any of the Merge
1905 // nodes is not in allow_set, otherwise it adds the NextIteration node to
1906 // allow_set.
1907 Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
1908 absl::flat_hash_set<int>* allow_set) const {
1909 for (const NodeDef& node : graph_->node()) {
1910 if (node.op() == "NextIteration") {
1911 GraphView::OutputPort output_port(&node, 0);
1912 const auto& fanout = graph_view_.GetFanout(output_port);
1913 std::vector<int> merge_idxs;
1914 merge_idxs.reserve(fanout.size());
1915 bool any_merge_is_not_allow = false;
1916 for (const auto& output : fanout) {
1917 const NodeDef& merge_node = *output.node;
1918 if (merge_node.op() != "Merge") {
1919 return errors::FailedPrecondition(
1920 "Expected Merge node after NextIteration, got ", merge_node.op());
1921 }
1922 const absl::optional<int> maybe_merge_idx =
1923 graph_type_view_.GetNodeIndex(merge_node.name(), TypeAttrId("T"));
1924 if (!maybe_merge_idx.has_value()) {
1925 return errors::Internal("Type attribute T of Merge node ",
1926 merge_node.name(),
1927 " not found in graph view");
1928 }
1929 int merge_idx = maybe_merge_idx.value();
1930 merge_idxs.push_back(merge_idx);
1931 any_merge_is_not_allow =
1932 any_merge_is_not_allow || !allow_set->count(merge_idx);
1933 }
1934 const absl::optional<int> maybe_nextiter_idx =
1935 graph_type_view_.GetNodeIndex(node.name(), TypeAttrId("T"));
1936 if (!maybe_nextiter_idx.has_value()) {
1937 return errors::Internal("Type attribute T of NextIteration node ",
1938 node.name(), " not found in graph view");
1939 }
1940 int nextiter_idx = maybe_nextiter_idx.value();
1941 if (any_merge_is_not_allow) {
1942 for (int merge_idx : merge_idxs) {
1943 if (allow_set->erase(merge_idx)) {
1944 VLOG(2) << "Painting type T of Merge node "
1945 << graph_type_view_.GetNode(merge_idx)->node->name()
1946 << " DENY to match the color of its sibling Merge nodes "
1947 "with common NextIteration node "
1948 << node.name();
1949 }
1950 }
1951 if (allow_set->erase(nextiter_idx)) {
1952 VLOG(2) << "Painting type T of NextIteration node " << node.name()
1953 << " DENY to match the color of its output Merge node(s)";
1954 }
1955 } else {
1956 if (allow_set->insert(nextiter_idx).second) {
1957 VLOG(2) << "Painting type T of NextIteration node " << node.name()
1958 << " ALLOW to match the color of its output Merge node(s)";
1959 }
1960 }
1961 }
1962 }
1963 return OkStatus();
1964 }
1965
1966 // Forces all of the given Tensor List nodes into the same color set.
1967 void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
1968 const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1969 absl::flat_hash_set<int>* allow_set,
1970 absl::flat_hash_set<int>* deny_set) const {
1971 bool any_deny = false;
1972 bool any_allow = false;
1973 std::vector<int> node_type_idxs;
1974 node_type_idxs.reserve(tensor_list_nodes.size());
1975 for (const NodeDef* node : tensor_list_nodes) {
1976 const NodeTypeId& node_type = *GetTensorListFloat32NodeTypeId(*node);
1977 const absl::optional<int> maybe_node_type_idx =
1978 graph_type_view_.GetNodeIndex(node_type);
1979 DCHECK(maybe_node_type_idx.has_value())
1980 << "Type attribute " << node_type.type_attr.DebugString() << " of node "
1981 << node->name() << " not found in graph view";
1982 node_type_idxs.push_back(maybe_node_type_idx.value());
1983 }
1984 for (int node_type_idx : node_type_idxs) {
1985 if (deny_set->count(node_type_idx)) {
1986 any_deny = true;
1987 break;
1988 } else if (allow_set->count(node_type_idx)) {
1989 any_allow = true;
1990 }
1991 }
1992 if (!any_deny && !any_allow) return;
1993 for (int node_type_idx : node_type_idxs) {
1994 const NodeTypeId& node_type = *graph_type_view_.GetNode(node_type_idx);
1995 VLOG(2) << "Painting type " << node_type.type_attr.DebugString() << " of "
1996 << node_type.node->op() << " node " << node_type.node->name() << " "
1997 << (any_deny ? "DENY" : "ALLOW")
1998 << " because at least one of its siblings is "
1999 << (any_deny ? "DENY" : "ALLOW");
2000 if (any_deny) {
2001 allow_set->erase(node_type_idx);
2002 deny_set->insert(node_type_idx);
2003 } else {
2004 allow_set->insert(node_type_idx);
2005 }
2006 }
2007 }
2008
2009 bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable(
2010 const NodeDef& node) const {
2011 if (node.op() == "Identity" || node.op() == "Enter") {
2012 GraphView::InputPort node_input(&node, 0);
2013 MutableGraphView::OutputPort prev_output =
2014 graph_view_.GetRegularFanin(node_input);
2015 const NodeDef* input = prev_output.node;
2016 if (input && ((node.op() == "Identity" && (input->op() == "Variable" ||
2017 input->op() == "VariableV2")) ||
2018 (node.op() == "Enter" &&
2019 NodeImplicitlyReadsNonResourceVariable(*input)))) {
2020 return true;
2021 }
2022 }
2023 return false;
2024 }
2025
2026 // This adds existing Cast nodes to allow_set if all of their outputs are allow,
2027 // avoiding the need to add a new Cast node after an existing Cast.
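// For example (illustrative), an existing Cast with DstT float32 whose fanouts
// are all ALLOW-painted is itself painted ALLOW, so its DstT is later rewritten
// to the f16 type instead of a second Cast being appended behind it.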
2028 void AutoMixedPrecisionImpl::MakeCastsAllowIfAllOutputsAllow(
2029 absl::flat_hash_set<int>* allow_set) const {
2030 int num_nodes_preop = graph_->node_size();
2031 for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
2032 NodeDef* node = graph_->mutable_node(node_idx);
2033 NodeTypeId node_type(node, TypeAttrId("DstT"));
2034 if (node->op() != "Cast" || !IsFloat32(node_type)) {
2035 continue;
2036 }
2037 bool all_fanouts_allow = true;
2038 MutableGraphView::OutputPort src(node, 0);
2039 const auto& fanout = graph_view_.GetFanout(src);
2040 for (const MutableGraphView::InputPort& dst : fanout) {
2041 TypeAttrId dst_type_attr =
2042 node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
2043 const absl::optional<int> maybe_dst_type_idx =
2044 graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
2045 DCHECK(maybe_dst_type_idx.has_value())
2046 << "Type attribute " << dst_type_attr.DebugString() << " of node "
2047 << dst.node->name() << " not found in graph view";
2048 int dst_type_idx = maybe_dst_type_idx.value();
2049 bool dst_is_allow = allow_set->count(dst_type_idx);
2050 if (!dst_is_allow) {
2051 all_fanouts_allow = false;
2052 break;
2053 }
2054 }
2055 if (!fanout.empty() && all_fanouts_allow) {
2056 const absl::optional<int> maybe_node_type_idx =
2057 graph_type_view_.GetNodeIndex(node_type);
2058 DCHECK(maybe_node_type_idx.has_value())
2059 << "Type attribute " << node_type.type_attr.DebugString()
2060 << " of node " << node_type.node->name()
2061 << " not found in graph view";
2062 int node_type_idx = maybe_node_type_idx.value();
2063 allow_set->insert(node_type_idx);
2064 }
2065 }
2066 }
2067
2068 // Inserts a Cast op at the fanout of a node.
2069 // CastType indicates the destination type of the inserted Cast op:
2070 // FP16: cast to float16
2071 // FP32: cast to float32
2072 // AUTO: cast to a data type that matches the fanout data type
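// For example (illustrative), with CastType::AUTO an edge from an ALLOW-painted
// (f16) output to a fanin that is not ALLOW-painted gets a Cast back to
// DT_FLOAT, while edges between two ALLOW-painted endpoints are left untouched.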
2073 StatusOr<NodeDef*> AutoMixedPrecisionImpl::InsertCastNodeAtFanout(
2074 const absl::flat_hash_set<int>& allow_set, const bool src_is_allow,
2075 const CastType& cast_type, MutableGraphView::OutputPort& src) {
2076 NodeDef* added_cast_node = nullptr;
2077 // Note: This is copied so that edges can be modified inside the loop.
2078 auto fanout = graph_view_.GetFanout(src);
2079 for (const MutableGraphView::InputPort& dst : fanout) {
2080 TypeAttrId dst_type_attr =
2081 node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
2082 const absl::optional<int> maybe_dst_type_idx =
2083 graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
2084 if (!maybe_dst_type_idx.has_value()) {
2085 return errors::Internal("Type attribute ", dst_type_attr.DebugString(),
2086 " of ", dst.node->op(), " node ",
2087 dst.node->name(), " not found in graph view");
2088 }
2089 int dst_type_idx = maybe_dst_type_idx.value();
2090 bool dst_is_allow = allow_set.count(dst_type_idx);
2091 bool to_f16 = false;
2092 bool should_cast = false;
2093 switch (cast_type) {
2094 case CastType::AUTO:
2095 if (src_is_allow != dst_is_allow) {
2096 to_f16 = dst_is_allow;
2097 should_cast = true;
2098 }
2099 break;
2100 case CastType::FP16:
2101 to_f16 = true;
2102 should_cast = true;
2103 break;
2104 case CastType::FP32:
2105 to_f16 = false;
2106 should_cast = true;
2107 break;
2108 default:
2109 return errors::Internal("Invalid Cast Type: ",
2110 static_cast<int>(cast_type));
2111 }
2112
2113 if (!should_cast) continue;
2114 if (added_cast_node == nullptr) {
2115 VLOG(1) << "Inserting cast to "
2116 << (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT") << " at "
2117 << src.node->op() << " " << src.node->name() << ":"
2118 << src.port_id;
2119 added_cast_node = graph_view_.AddNode(
2120 BuildCastNode(src, dst, to_f16, src.node->device()));
2121 if (to_f16 && !IsConstant(*src.node) && !IsVariable(*src.node) &&
2122 !NodeImplicitlyReadsNonResourceVariable(*src.node)) {
2123 ++num_nonvar_casts_to_f16_;
2124 }
2125 }
2126 TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
2127 dst.node->name(), dst.port_id, {added_cast_node->name(), 0}));
2128 }
2129 return added_cast_node;
2130 }
2131
2132 // Gets the destination data type of a Cast op. CHECK-fails if the node is
2133 // not a Cast op.
2134 StatusOr<DataType> AutoMixedPrecisionImpl::GetCastToType(
2135 const NodeDef* node) const {
2136 CHECK_EQ(node->op(), "Cast") // Crash OK
2137 << "Node " << node->name() << " is not a Cast op";
2138 return node->attr().at("DstT").type();
2139 }
2140
2141 // Collect the output ports of a node based on a type attribute and append them
2142 // to a vector.
2143 // Input: type_attr
2144 // Input: node
2145 // Output: output_ports
2146 void AutoMixedPrecisionImpl::CollectOutputPorts(
2147 const TypeAttrId& type_attr, NodeDef* node,
2148 std::vector<MutableGraphView::OutputPort>& output_ports) const {
2149 for (int port_id : node_type_map_.GetOutputPorts(*node, type_attr)) {
2150 output_ports.emplace_back(node, port_id);
2151 }
2152 }
2153
2154 // Changes all allow-painted type attributes to DT_HALF or DT_BFLOAT16, and
2155 // inserts Cast nodes at node outputs for all edges that connect
2156 // allow-painted <-> non-allow-painted type attributes.
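// For example (illustrative), if a MatMul's type attribute T is ALLOW-painted
// but one of its consumers is not, T is rewritten to the target f16 type and a
// Cast back to DT_FLOAT is inserted on the edge to that consumer.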
2157 Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
2158 const absl::flat_hash_set<int>& allow_set) {
2159 int num_nodes_changed = 0;
2160 const int num_nodes_preop = graph_->node_size();
2161
2162 bool emulate_f16 = false;
2163 if (mode_ == AutoMixedPrecisionMode::CPU) {
2164 TF_CHECK_OK(
2165 ReadBoolFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_EMULATE_FP16",
2166 /*default_val=*/true, &emulate_f16));
2167 }
2168
2169 VLOG(1) << "Setting emulate_f16 = " << emulate_f16;
2170
2171 for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
2172 NodeDef* node = graph_->mutable_node(node_idx);
2173 for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(*node)) {
2174 const absl::optional<int> maybe_node_type_idx =
2175 graph_type_view_.GetNodeIndex(node->name(), type_attr);
2176 if (!maybe_node_type_idx.has_value()) {
2177 return errors::Internal("Type attribute ", type_attr.DebugString(),
2178 " of ", node->op(), " node ", node->name(),
2179 " not found in graph view");
2180 }
2181 int node_type_idx = maybe_node_type_idx.value();
2182 if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue;
2183 bool src_is_allow = allow_set.count(node_type_idx);
2184
2185 // Include output ports of fp32 nodes, real fp16 nodes,
2186 // and the fp16 Cast nodes at the fanout of emulated fp16 ops.
2187 std::vector<MutableGraphView::OutputPort> output_ports;
2188
2189 if (src_is_allow) {
2190 if (emulate_f16) {
2191 // For an emulated fp16 op, we do not change the op type; instead we
2192 // insert an fp32 Cast at each fanin and an fp16 Cast at each fanout.
2193 for (int port_id : node_type_map_.GetInputPorts(*node, type_attr)) {
2194 VLOG(2) << "Cast to F32 at fanin of node " << node->name() << ":"
2195 << port_id;
2196 MutableGraphView::InputPort dst(node, port_id);
2197 MutableGraphView::OutputPort src = graph_view_.GetRegularFanin(dst);
2198 NodeDef* added_cast_node = graph_view_.AddNode(
2199 BuildCastNode(src, dst, /*to_f16=*/false, src.node->device()));
2200 VLOG(1) << "Inserting cast to DT_FLOAT at " << src.node->op() << " "
2201 << src.node->name() << ":" << src.port_id;
2202 TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
2203 dst.node->name(), dst.port_id, {added_cast_node->name(), 0}));
2204 }
2205 // Cast to fp16 at outputs
2206 for (int port_id : node_type_map_.GetOutputPorts(*node, type_attr)) {
2207 MutableGraphView::OutputPort src(node, port_id);
2208 VLOG(2) << "Cast to F16 at fanout of node " << node->name() << ":"
2209 << port_id;
2210 TF_ASSIGN_OR_RETURN(NodeDef * added_cast_node,
2211 InsertCastNodeAtFanout(allow_set, src_is_allow,
2212 CastType::FP16, src));
2213 if (added_cast_node != nullptr) {
2214 output_ports.emplace_back(added_cast_node, /*port_id=*/0);
2215 }
2216 }
2217 } else {
2218 VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
2219 << node->op() << " node " << node->name() << " to "
2220 << DataTypeString(target_dtype_);
2221 if (!SetDataType(node, type_attr, target_dtype_)) {
2222 return errors::Internal("Failed to set type attribute");
2223 }
2224 ++num_nodes_changed;
2225 CollectOutputPorts(type_attr, node, output_ports);
2226 }
2227 } else {
2228 CollectOutputPorts(type_attr, node, output_ports);
2229 }
2230
2231 // If the fanouts require a different data type from the output of the
2232 // current node, insert a Cast op.
2233 for (auto output_port : output_ports) {
2234 VLOG(2) << "Cast to required data type at fanout of node "
2235 << output_port.node->name() << ":" << output_port.port_id;
2236 TF_RETURN_IF_ERROR(InsertCastNodeAtFanout(allow_set, src_is_allow,
2237 CastType::AUTO, output_port)
2238 .status());
2239 }
2240 }
2241 }
2242
2243 // Use Python type names (e.g. float16) instead of C++ type names (e.g. half)
2244 // since many Python users will see this message.
2245 const char* type_str = target_dtype_ == DT_HALF ? "float16" : "bfloat16";
2246 LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop
2247 << " nodes to " << type_str << " precision using "
2248 << num_nonvar_casts_to_f16_ << " cast(s) to " << type_str
2249 << " (excluding Const and Variable casts)";
2250 return OkStatus();
2251 }
2252
2253 int GetNumGPUs(const Cluster& cluster) {
2254 if (ShouldSimulateGpu()) {
2255 return 1;
2256 }
2257 auto devices = cluster.GetDevices();
2258 int num_gpus = 0;
2259 for (const auto& device : devices) {
2260 const DeviceProperties& device_properties = device.second;
2261 if (device_properties.type() == "GPU" &&
2262 (ShouldIgnorePerformance() || HasFastFP16Support(device_properties))) {
2263 num_gpus++;
2264 }
2265 }
2266 return num_gpus;
2267 }
2268
2269 } // end namespace
2270
2271 Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
2272 GraphDef* output) {
2273 if (cluster == nullptr) {
2274 return errors::InvalidArgument("cluster == nullptr");
2275 }
2276
2277 #if !defined(INTEL_MKL)
2278 if (mode_ == AutoMixedPrecisionMode::BF16) {
2279 return errors::Unimplemented(
2280 "The auto_mixed_precision_onednn_bfloat16 optimizer cannot be used "
2281 "since this build of TensorFlow is not compiled with oneDNN support "
2282 "for bfloat16. "
2283 "For information on oneDNN builds, see: "
2284 "https://software.intel.com/en-us/articles/intel-optimization-for-"
2285 "tensorflow-installation-guide");
2286 }
2287 #endif // INTEL_MKL
2288
2289 // Start by copying input graph to output.
2290 *output = item.graph;
2291
2292 int num_gpus = GetNumGPUs(*cluster);
2293 if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) {
2294 // AutoMixedPrecision is currently only tuned for GPU.
2295 LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name()
2296 << " graph optimizer";
2297 return OkStatus();
2298 }
2299
2300 if (num_gpus >= 1 && mode_ == AutoMixedPrecisionMode::BF16) {
2301 LOG(WARNING) << "Note: GPUs detected. Using " << name()
2302 << " graph optimizer configured for BFloat16 on CPUs";
2303 }
2304
2305 // Optimize the output graph in-place.
2306 AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output,
2307 item.id, mode_);
2308 if (item.id == "tf_graph") {
2309 LOG(INFO) << "Running " << name() << " graph optimizer";
2310 } else {
2311 VLOG(1) << "Running " << name() << " graph optimizer on " << item.id;
2312 }
2313 Status status = optimizer.Optimize();
2314 if (!status.ok()) {
2315 // Restore the original graph.
2316 *output = item.graph;
2317 LOG(WARNING) << name() << " graph optimizer FAILED: " << status.ToString();
2318 }
2319 return status;
2320 }
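
// A minimal usage sketch (hypothetical, for illustration only) of driving this
// optimizer directly on a GrapplerItem; it assumes a populated item and an
// initialized cluster, and that the optimizer is constructed with a mode:
//
//   GrapplerItem item = ...;                 // graph plus nodes to preserve
//   std::unique_ptr<Cluster> cluster = ...;  // provides device properties
//   AutoMixedPrecision optimizer(AutoMixedPrecisionMode::CUDA);
//   GraphDef output;
//   Status status = optimizer.Optimize(cluster.get(), item, &output);
//
// On failure, Optimize restores the original graph into `output` and returns
// the error status, as shown above.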
2321
2322 } // end namespace grappler
2323 } // end namespace tensorflow
2324