1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_ 17 18 #include <unordered_map> 19 #include <vector> 20 21 #include "tensorflow/core/common_runtime/device_set.h" 22 #include "tensorflow/core/framework/function.h" 23 #include "tensorflow/core/lib/core/status.h" 24 25 namespace tensorflow { 26 27 // Given a `device_set` and a `graph`, partitions the `graph` into 28 // `subgraphs`. `subgraphs` maps device names to the graph assigned to that 29 // device. `graph` must have been placed (e.g. by running Placer), 30 // i.e. all nodes must have an assigned_device set. 31 // `graph` is non-const because the underlying Partition() function transforms 32 // the graph to correctly partition distributed control flow. 33 // `get_tensor_name_attr` computes the "tensor_name" attr value of Send/Recv ops 34 // inserted during partitioning. Use the default one if not set. It needs to be 35 // thread safe if it's shared in multple threads. 36 Status PartitionFunctionGraph( 37 const DeviceSet& device_set, std::unique_ptr<Graph> graph, 38 std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs, 39 std::function<string(const Edge*)> get_tensor_name_attr = nullptr); 40 41 // Inserts send/recv ops to `graph` if nodes are assigned to multiple devices. 42 // Returns the new graph with the added nodes. 43 StatusOr<std::unique_ptr<Graph>> InsertTransferOps( 44 const DeviceSet& device_set, std::unique_ptr<Graph> graph); 45 46 // This function performs bookkeeping to track which `Arg` and `Retval` nodes 47 // were placed on a particular device / graph. 48 // 49 // More specifically, this function 50 // 51 // (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be 52 // consecutive. 53 // 54 // These indices might not be consecutive after grappler's pruning 55 // optimization (e.g. removing redundant Args), or graph partitioning. In 56 // the latter case, the nodes in `graph` are placed on `device_type`, and 57 // each such graph partition gets a subset of the arguments and return 58 // values. The `index` attributes of these _Arg and _Retval nodes reflect 59 // the indices of these parameters in the original function. To convert 60 // `subgraph` to a function, we need to replace there original indices with 61 // 0, 1, 2, ... . 62 // 63 // The argument and return value order in `graph` is determined by the 64 // argument and return value order in the original function. This stability 65 // is important because it enables us to treat a single-partition function 66 // as having the same signature as the subgraph. 67 // 68 // (2) records the subsets of `Arg` and `Retval` nodes assigned to the 69 // device in `*_indices`, and 70 // (3) records which `Arg` and `Retval` nodes live in host memory in 71 // `*_alloc_attrs`. If these vectors are NULL, do nothing here. If 72 // `ints_on_device` is false, int32 `Arg` and `Retval` nodes are placed on 73 // host else not. This is needed because in certain special cases e.g. 74 // when graph is placed on TPU/XLA device or when the `Retval` is an output 75 // of an iterator, int32 tensors live on device. 76 Status UpdateArgAndRetvalMetadata( 77 Graph* graph, std::vector<FunctionArgIndex>* arg_indices, 78 std::vector<int>* ret_indices, 79 std::vector<AllocatorAttributes>* arg_alloc_attrs, 80 std::vector<AllocatorAttributes>* ret_alloc_attrs, bool ints_on_device); 81 82 // Utility for generating function names not present in `flib_def`, using 83 // given `name` as the base for the name. 84 class FunctionNameGenerator { 85 public: 86 // `flib_def` must outlive this. FunctionNameGenerator(const FunctionLibraryDefinition * flib_def,const string & name)87 FunctionNameGenerator(const FunctionLibraryDefinition* flib_def, 88 const string& name) 89 : flib_def_(flib_def), name_(name), counter_(0) {} 90 91 // Returns a function name not present in `flib_def` using `name` as 92 // the base and appending a numeric suffix. 93 string GetName(); 94 95 private: 96 const FunctionLibraryDefinition* flib_def_; 97 const string name_; 98 uint32 counter_; 99 }; 100 101 } // namespace tensorflow 102 103 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_ 104