1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ 18 19 #include <memory> 20 #include <string> 21 #include <unordered_map> 22 #include <unordered_set> 23 #include <vector> 24 25 #include "absl/container/flat_hash_set.h" 26 #include "tensorflow/core/common_runtime/device.h" 27 #include "tensorflow/core/lib/core/arena.h" 28 #include "tensorflow/core/lib/core/status.h" 29 #include "tensorflow/core/lib/core/stringpiece.h" 30 #include "tensorflow/core/lib/gtl/inlined_vector.h" 31 #include "tensorflow/core/platform/macros.h" 32 33 namespace tensorflow { 34 35 class DeviceAttributes; 36 37 // Represents a set of devices. 38 class DeviceMgr { 39 public: 40 DeviceMgr() = default; 41 virtual ~DeviceMgr(); 42 43 // Returns attributes of all devices. 44 virtual void ListDeviceAttributes( 45 std::vector<DeviceAttributes>* devices) const = 0; 46 47 // Returns raw pointers to the underlying devices. 48 virtual std::vector<Device*> ListDevices() const = 0; 49 50 // Returns a string listing all devices. 51 virtual string DebugString() const = 0; 52 53 // Returns a string of all the device mapping. 54 virtual string DeviceMappingString() const = 0; 55 56 // Assigns *device with pointer to Device of the given name. 57 // Accepts either a full device name, or just the replica-local suffix. 58 virtual Status LookupDevice(StringPiece name, Device** device) const = 0; 59 60 // Check if the current device manager contains device with the given 61 // incarnation ID. Looking up by incarnation IDs because they are randomly 62 // generated and not intentionally reused (unlike device pointers). 63 virtual bool ContainsDevice(int64_t device_incarnation) const = 0; 64 65 // Clears given containers of all devices if 'container' is 66 // non-empty. Otherwise, clears default containers of all devices. 67 virtual void ClearContainers(gtl::ArraySlice<string> containers) const = 0; 68 69 virtual int NumDeviceType(const string& type) const = 0; 70 71 // Returns an arbitrary CPU device if one is present, otherwise return 72 // nullptr. 73 virtual Device* HostCPU() const = 0; 74 75 TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr); 76 }; 77 78 // Represents a static set of devices. 79 class StaticDeviceMgr : public DeviceMgr { 80 public: 81 // Constructs a StaticDeviceMgr from a list of devices. 82 explicit StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices); 83 84 // Constructs a StaticDeviceMgr managing a single device. 85 explicit StaticDeviceMgr(std::unique_ptr<Device> device); 86 87 ~StaticDeviceMgr() override; 88 89 void ListDeviceAttributes( 90 std::vector<DeviceAttributes>* devices) const override; 91 std::vector<Device*> ListDevices() const override; 92 string DebugString() const override; 93 string DeviceMappingString() const override; 94 Status LookupDevice(StringPiece name, Device** device) const override; 95 bool ContainsDevice(int64_t device_incarnation) const override; 96 void ClearContainers(gtl::ArraySlice<string> containers) const override; 97 int NumDeviceType(const string& type) const override; 98 Device* HostCPU() const override; 99 100 private: 101 const std::vector<std::unique_ptr<Device>> devices_; 102 103 StringPiece CopyToBackingStore(StringPiece s); 104 105 absl::flat_hash_set<int64_t> device_incarnation_set_; 106 std::unordered_map<StringPiece, Device*, StringPieceHasher> device_map_; 107 core::Arena name_backing_store_; // Storage for keys in device_map_ 108 std::unordered_map<string, int> device_type_counts_; 109 Device* cpu_device_; 110 111 TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr); 112 }; 113 114 // Size of stale device buffer for temporary storage of removed devices. 115 static const size_t kStaleDeviceBufferSize = 8192; 116 117 // Represents a dynamic set of devices 118 class DynamicDeviceMgr : public DeviceMgr { 119 public: 120 // Constructs an empty DynamicDeviceMgr. 121 DynamicDeviceMgr(); 122 123 // Constructs a DynamicDeviceMgr from a list of devices. 124 // TODO(b/183966398): Remove StaticDeviceMgr since there's no usage. 125 explicit DynamicDeviceMgr(std::vector<std::unique_ptr<Device>> devices); 126 127 ~DynamicDeviceMgr() override; 128 129 void ListDeviceAttributes( 130 std::vector<DeviceAttributes>* devices) const override; 131 std::vector<Device*> ListDevices() const override; 132 string DebugString() const override; 133 string DeviceMappingString() const override; 134 Status LookupDevice(StringPiece name, Device** device) const override; 135 bool ContainsDevice(int64_t device_incarnation) const override; 136 void ClearContainers(gtl::ArraySlice<string> containers) const override; 137 int NumDeviceType(const string& type) const override; 138 Device* HostCPU() const override; 139 140 // Add devices to device manager. Returns error for repeated device names. 141 Status AddDevices(std::vector<std::unique_ptr<Device>> devices); 142 143 // Remove devices from device manager. 144 // Returns error for non-existing devices or if the HostCPU() device is in the 145 // input list. If an error is returned, the device list is not modified. 146 Status RemoveDevices(const std::vector<Device*>& devices); 147 148 // Remove devices from device manager by their names. Returns error for 149 // non-existing devices or if the HostCPU() device is given in the input list. 150 // If an error is returned, the device list is not modified. 151 Status RemoveDevicesByName(const std::vector<string>& device_names); 152 153 private: 154 mutable mutex devices_mu_; 155 156 std::vector<std::unique_ptr<Device>> dynamic_devices_ 157 TF_GUARDED_BY(devices_mu_); 158 159 absl::flat_hash_set<int64_t> device_incarnation_set_ 160 TF_GUARDED_BY(devices_mu_); 161 std::unordered_map<string, Device*> device_map_ TF_GUARDED_BY(devices_mu_); 162 163 std::unordered_map<string, int> device_type_counts_ 164 TF_GUARDED_BY(devices_mu_); 165 166 mutable std::atomic<Device*> cpu_device_; // memoize `HostCPU` result 167 168 class DeviceCircularBuffer { 169 public: DeviceCircularBuffer()170 DeviceCircularBuffer() : index_(0) { 171 devices_.resize(kStaleDeviceBufferSize); 172 } add(std::unique_ptr<Device> device)173 void add(std::unique_ptr<Device> device) { 174 devices_[index_] = std::move(device); 175 index_ = (index_ + 1) % kStaleDeviceBufferSize; 176 } 177 178 private: 179 int index_; 180 std::vector<std::unique_ptr<Device>> devices_; 181 }; 182 183 // Buffer to temporarily store the removed devices. Raw device pointers are 184 // accessible to DeviceSet, and if the function instantiation process directly 185 // access fields through the device set, the underlying device object must 186 // still be available to avoid segmentation fault. We keep the devices in this 187 // buffer only for that purpose. 188 DeviceCircularBuffer stale_devices_ TF_GUARDED_BY(devices_mu_); 189 190 TF_DISALLOW_COPY_AND_ASSIGN(DynamicDeviceMgr); 191 }; 192 } // namespace tensorflow 193 194 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ 195