// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/device_mgr.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
18 
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/core/arena.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/macros.h"
32 
33 namespace tensorflow {
34 
// Forward declaration only: this header uses DeviceAttributes solely as the
// element type of a std::vector reached through a pointer parameter.
class DeviceAttributes;
36 
37 // Represents a set of devices.
38 class DeviceMgr {
39  public:
40   DeviceMgr() = default;
41   virtual ~DeviceMgr();
42 
43   // Returns attributes of all devices.
44   virtual void ListDeviceAttributes(
45       std::vector<DeviceAttributes>* devices) const = 0;
46 
47   // Returns raw pointers to the underlying devices.
48   virtual std::vector<Device*> ListDevices() const = 0;
49 
50   // Returns a string listing all devices.
51   virtual string DebugString() const = 0;
52 
53   // Returns a string of all the device mapping.
54   virtual string DeviceMappingString() const = 0;
55 
56   // Assigns *device with pointer to Device of the given name.
57   // Accepts either a full device name, or just the replica-local suffix.
58   virtual Status LookupDevice(StringPiece name, Device** device) const = 0;
59 
60   // Check if the current device manager contains device with the given
61   // incarnation ID. Looking up by incarnation IDs because they are randomly
62   // generated and not intentionally reused (unlike device pointers).
63   virtual bool ContainsDevice(int64_t device_incarnation) const = 0;
64 
65   // Clears given containers of all devices if 'container' is
66   // non-empty. Otherwise, clears default containers of all devices.
67   virtual void ClearContainers(gtl::ArraySlice<string> containers) const = 0;
68 
69   virtual int NumDeviceType(const string& type) const = 0;
70 
71   // Returns an arbitrary CPU device if one is present, otherwise return
72   // nullptr.
73   virtual Device* HostCPU() const = 0;
74 
75   TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
76 };
77 
78 // Represents a static set of devices.
79 class StaticDeviceMgr : public DeviceMgr {
80  public:
81   // Constructs a StaticDeviceMgr from a list of devices.
82   explicit StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices);
83 
84   // Constructs a StaticDeviceMgr managing a single device.
85   explicit StaticDeviceMgr(std::unique_ptr<Device> device);
86 
87   ~StaticDeviceMgr() override;
88 
89   void ListDeviceAttributes(
90       std::vector<DeviceAttributes>* devices) const override;
91   std::vector<Device*> ListDevices() const override;
92   string DebugString() const override;
93   string DeviceMappingString() const override;
94   Status LookupDevice(StringPiece name, Device** device) const override;
95   bool ContainsDevice(int64_t device_incarnation) const override;
96   void ClearContainers(gtl::ArraySlice<string> containers) const override;
97   int NumDeviceType(const string& type) const override;
98   Device* HostCPU() const override;
99 
100  private:
101   const std::vector<std::unique_ptr<Device>> devices_;
102 
103   StringPiece CopyToBackingStore(StringPiece s);
104 
105   absl::flat_hash_set<int64_t> device_incarnation_set_;
106   std::unordered_map<StringPiece, Device*, StringPieceHasher> device_map_;
107   core::Arena name_backing_store_;  // Storage for keys in device_map_
108   std::unordered_map<string, int> device_type_counts_;
109   Device* cpu_device_;
110 
111   TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr);
112 };
113 
// Size (in slots) of the stale device buffer used for temporary storage of
// removed devices. `constexpr` guarantees a compile-time constant; `static`
// keeps the internal linkage the constant has always had in this header.
static constexpr size_t kStaleDeviceBufferSize = 8192;
116 
117 // Represents a dynamic set of devices
118 class DynamicDeviceMgr : public DeviceMgr {
119  public:
120   // Constructs an empty DynamicDeviceMgr.
121   DynamicDeviceMgr();
122 
123   // Constructs a DynamicDeviceMgr from a list of devices.
124   // TODO(b/183966398): Remove StaticDeviceMgr since there's no usage.
125   explicit DynamicDeviceMgr(std::vector<std::unique_ptr<Device>> devices);
126 
127   ~DynamicDeviceMgr() override;
128 
129   void ListDeviceAttributes(
130       std::vector<DeviceAttributes>* devices) const override;
131   std::vector<Device*> ListDevices() const override;
132   string DebugString() const override;
133   string DeviceMappingString() const override;
134   Status LookupDevice(StringPiece name, Device** device) const override;
135   bool ContainsDevice(int64_t device_incarnation) const override;
136   void ClearContainers(gtl::ArraySlice<string> containers) const override;
137   int NumDeviceType(const string& type) const override;
138   Device* HostCPU() const override;
139 
140   // Add devices to device manager. Returns error for repeated device names.
141   Status AddDevices(std::vector<std::unique_ptr<Device>> devices);
142 
143   // Remove devices from device manager.
144   // Returns error for non-existing devices or if the HostCPU() device is in the
145   // input list. If an error is returned, the device list is not modified.
146   Status RemoveDevices(const std::vector<Device*>& devices);
147 
148   // Remove devices from device manager by their names. Returns error for
149   // non-existing devices or if the HostCPU() device is given in the input list.
150   // If an error is returned, the device list is not modified.
151   Status RemoveDevicesByName(const std::vector<string>& device_names);
152 
153  private:
154   mutable mutex devices_mu_;
155 
156   std::vector<std::unique_ptr<Device>> dynamic_devices_
157       TF_GUARDED_BY(devices_mu_);
158 
159   absl::flat_hash_set<int64_t> device_incarnation_set_
160       TF_GUARDED_BY(devices_mu_);
161   std::unordered_map<string, Device*> device_map_ TF_GUARDED_BY(devices_mu_);
162 
163   std::unordered_map<string, int> device_type_counts_
164       TF_GUARDED_BY(devices_mu_);
165 
166   mutable std::atomic<Device*> cpu_device_;  // memoize `HostCPU` result
167 
168   class DeviceCircularBuffer {
169    public:
DeviceCircularBuffer()170     DeviceCircularBuffer() : index_(0) {
171       devices_.resize(kStaleDeviceBufferSize);
172     }
add(std::unique_ptr<Device> device)173     void add(std::unique_ptr<Device> device) {
174       devices_[index_] = std::move(device);
175       index_ = (index_ + 1) % kStaleDeviceBufferSize;
176     }
177 
178    private:
179     int index_;
180     std::vector<std::unique_ptr<Device>> devices_;
181   };
182 
183   // Buffer to temporarily store the removed devices. Raw device pointers are
184   // accessible to DeviceSet, and if the function instantiation process directly
185   // access fields through the device set, the underlying device object must
186   // still be available to avoid segmentation fault. We keep the devices in this
187   // buffer only for that purpose.
188   DeviceCircularBuffer stale_devices_ TF_GUARDED_BY(devices_mu_);
189 
190   TF_DISALLOW_COPY_AND_ASSIGN(DynamicDeviceMgr);
191 };
192 }  // namespace tensorflow
193 
194 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
195