/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/platform_util.h"

#include <algorithm>
#include <string>
#include <utility>

#include "absl/strings/ascii.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

// Minimum supported CUDA compute capability is 3.5.
constexpr int kMinCudaComputeCapabilityMajor = 3;
constexpr int kMinCudaComputeCapabilityMinor = 5;

// The name of the interpreter platform.
constexpr char kInterpreter[] = "interpreter";

namespace {

std::string CanonicalPlatformName(const std::string& platform_name) {
  std::string lowercase_platform_name = absl::AsciiStrToLower(platform_name);
  // "cpu" and "host" mean the same thing.
  if (lowercase_platform_name == "cpu") {
    return "host";
  }
  // When configured on CUDA, "gpu" and "cuda" mean the same thing.
  // When configured on ROCm, "gpu" and "rocm" mean the same thing.
  if (lowercase_platform_name == "gpu") {
#if TENSORFLOW_USE_ROCM
    return "rocm";
#else
    return "cuda";
#endif
  }
  return lowercase_platform_name;
}
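
// A few illustrative mappings of the helper above: "CPU" and "Host" both
// canonicalize to "host", and "gpu" canonicalizes to "cuda" (or to "rocm"
// when TENSORFLOW_USE_ROCM is defined); any other name is simply lowercased.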

StatusOr<std::vector<se::Platform*>> GetSupportedPlatforms() {
  return se::MultiPlatformManager::PlatformsWithFilter(
      [](const se::Platform* platform) {
        auto compiler_status = Compiler::GetForPlatform(platform);
        bool supported = compiler_status.ok();
        if (!supported) {
          LOG(INFO) << "platform " << platform->Name() << " present but no "
                    << "XLA compiler available: "
                    << compiler_status.status().error_message();
        }
        return supported;
      });
}

}  // namespace

/* static */ StatusOr<std::vector<se::Platform*>>
PlatformUtil::GetSupportedPlatforms() {
  // Gather all platforms which have an XLA compiler.
  return xla::GetSupportedPlatforms();
}

/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
  TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());

  se::Platform* platform = nullptr;
  if (platforms.empty()) {
    return NotFound("no platforms found");
  } else if (platforms.size() == 1) {
    platform = platforms[0];
  } else if (platforms.size() == 2) {
    for (int i = 0; i < 2; i++) {
      if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter &&
          absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) {
        platform = platforms[1 - i];
        break;
      }
    }
  }
  if (platform != nullptr) {
    return platform;
  }

  // Multiple platforms present and we can't pick a reasonable default.
  std::string platforms_string = absl::StrJoin(
      platforms, ", ",
      [](std::string* out, const se::Platform* p) { out->append(p->Name()); });
  return InvalidArgument(
      "must specify platform because more than one platform (except for the "
      "interpreter platform) found: %s.",
      platforms_string);
}

/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
    const std::string& platform_name) {
  TF_ASSIGN_OR_RETURN(se::Platform * platform,
                      se::MultiPlatformManager::PlatformWithName(
                          CanonicalPlatformName(platform_name)));
  TF_RETURN_IF_ERROR(Compiler::GetForPlatform(platform).status());
  return platform;
}
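
// Illustrative use of GetPlatform from code that itself returns a Status or
// StatusOr (the variable name is hypothetical):
//
//   TF_ASSIGN_OR_RETURN(se::Platform * platform,
//                       PlatformUtil::GetPlatform("gpu"));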

// Returns whether the device underlying the given StreamExecutor is supported
// by XLA.
static bool IsDeviceSupported(se::StreamExecutor* executor) {
  const auto& description = executor->GetDeviceDescription();
  if (executor->platform()->id() == se::cuda::kCudaPlatformId) {
    // CUDA devices must have a minimum compute capability.
    se::CudaComputeCapability cc = description.cuda_compute_capability();
    if (!cc.IsAtLeast(kMinCudaComputeCapabilityMajor,
                      kMinCudaComputeCapabilityMinor)) {
      LOG(INFO) << "StreamExecutor cuda device (" << executor->device_ordinal()
                << ") is of "
                << "insufficient compute capability: "
                << kMinCudaComputeCapabilityMajor << "."
                << kMinCudaComputeCapabilityMinor << " required, "
                << "device is " << cc.ToString();
      return false;
    }
  } else if (executor->platform()->id() == se::rocm::kROCmPlatformId) {
    auto rocm_compute_capability = description.rocm_compute_capability();
    if (!rocm_compute_capability.is_supported_gfx_version()) {
      LOG(INFO) << "StreamExecutor ROCM device (" << executor->device_ordinal()
                << ") is of unsupported "
                << "AMDGPU version : " << rocm_compute_capability.gfx_version()
                << ". The supported AMDGPU versions are "
                << rocm_compute_capability.supported_gfx_versions_str() << ".";
      return false;
    }
  }
  return true;
}
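
// For example, a CUDA device that reports compute capability 3.0 is rejected
// by the check above, while one reporting 3.5 or newer is accepted.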

/* static */ StatusOr<std::vector<se::StreamExecutor*>>
PlatformUtil::GetStreamExecutors(
    se::Platform* platform,
    const std::optional<std::set<int>>& allowed_devices) {
  int device_count = platform->VisibleDeviceCount();
  if (device_count <= 0) {
    return NotFound("no %s devices found", platform->Name());
  }
  if (platform->id() == se::host::kHostPlatformId) {
    // On host "devices", StreamExecutor exports a device for each hardware
    // thread. Because we parallelize a single computation across threads, it
    // doesn't make sense to expose these as separate devices, so by default
    // we fix the number of devices to one. However, we do let the user
    // override this behavior to help run tests on the host that run models in
    // parallel across multiple devices.
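    // For example, setting --xla_force_host_platform_device_count=4
    // (typically via the XLA_FLAGS environment variable) exposes four host
    // "devices".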
    device_count =
        GetDebugOptionsFromFlags().xla_force_host_platform_device_count();
  }
  std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
  VLOG(1) << "Initializing devices";
  {
    tensorflow::thread::ThreadPool thread_pool(
        tensorflow::Env::Default(), "device_initialization", device_count);
    auto create_fn = [](se::Platform* platform,
                        std::vector<se::StreamExecutor*>& stream_executors,
                        int device_ordinal, int count) {
      VLOG(1) << "Started device init " << device_ordinal;
      auto executor_status = platform->ExecutorForDevice(device_ordinal);
      if (executor_status.ok()) {
        se::StreamExecutor* executor = executor_status.ValueOrDie();
        if (IsDeviceSupported(executor)) {
          stream_executors[count] = executor;
        }
      } else {
        LOG(WARNING) << "unable to create StreamExecutor for "
                     << platform->Name() << ":" << device_ordinal << ": "
                     << executor_status.status().error_message();
      }
      VLOG(1) << "Finished device init " << device_ordinal;
    };
    // Once a stream executor is instantiated it will cause allocations on
    // the device; for example, for GPUs a CUDA context, cuDNN handles, etc.
    // will be constructed. By constructing stream executors only for the
    // allowed_devices, we don't make any allocations on other devices.
    // This helps in multi-process executions on the same host, such as with
    // Horovod or on shared hosts.
    if (allowed_devices) {
      int count = 0;
      for (const auto& i : *allowed_devices) {
        if (count >= device_count) {
          break;
        }
        thread_pool.Schedule(
            [platform, &stream_executors, i, count, &create_fn]() {
              create_fn(platform, stream_executors, i, count);
            });
        count++;
      }
    } else {
      for (int i = 0; i < device_count; ++i) {
        thread_pool.Schedule([platform, &stream_executors, i, &create_fn]() {
          create_fn(platform, stream_executors, i, i);
        });
      }
    }
    // Block here in thread_pool destructor until all devices are initialized.
  }
  VLOG(1) << "Device initialization complete";

  std::vector<se::StreamExecutor*> out;
  for (se::StreamExecutor* executor : stream_executors) {
    if (executor != nullptr) {
      out.push_back(executor);
    }
  }
  if (out.empty()) {
    return InternalError("no supported devices found for platform %s",
                         platform->Name());
  }
  return out;
}
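
// Illustrative end-to-end use of the functions above, again from code that
// returns a Status or StatusOr (the variable names are hypothetical):
//
//   TF_ASSIGN_OR_RETURN(se::Platform * platform,
//                       PlatformUtil::GetDefaultPlatform());
//   TF_ASSIGN_OR_RETURN(std::vector<se::StreamExecutor*> executors,
//                       PlatformUtil::GetStreamExecutors(platform,
//                                                        std::nullopt));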

}  // namespace xla