xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/platform_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/platform_util.h"
17 
18 #include <algorithm>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/strings/ascii.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/debug_options_flags.h"
25 #include "tensorflow/compiler/xla/service/compiler.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
33 
namespace xla {

// Minimum supported CUDA compute capability is 3.5; devices below this are
// rejected by IsDeviceSupported() later in this file.
constexpr int kMinCudaComputeCapabilityMajor = 3;
constexpr int kMinCudaComputeCapabilityMinor = 5;

// The name of the interpreter platform, used to deprioritize it when picking
// a default platform.
constexpr char kInterpreter[] = "interpreter";
42 
43 namespace {
44 
CanonicalPlatformName(const std::string & platform_name)45 std::string CanonicalPlatformName(const std::string& platform_name) {
46   std::string lowercase_platform_name = absl::AsciiStrToLower(platform_name);
47   // "cpu" and "host" mean the same thing.
48   if (lowercase_platform_name == "cpu") {
49     return "host";
50   }
51   // When configured on CUDA, "gpu" and "cuda" mean the same thing.
52   // When configured on ROCm, "gpu" and "rocm" mean the same thing.
53   if (lowercase_platform_name == "gpu") {
54 #if TENSORFLOW_USE_ROCM
55     return "rocm";
56 #else
57     return "cuda";
58 #endif
59   }
60   return lowercase_platform_name;
61 }
62 
GetSupportedPlatforms()63 StatusOr<std::vector<se::Platform*>> GetSupportedPlatforms() {
64   return se::MultiPlatformManager::PlatformsWithFilter(
65       [](const se::Platform* platform) {
66         auto compiler_status = Compiler::GetForPlatform(platform);
67         bool supported = compiler_status.ok();
68         if (!supported) {
69           LOG(INFO) << "platform " << platform->Name() << " present but no "
70                     << "XLA compiler available: "
71                     << compiler_status.status().error_message();
72         }
73         return supported;
74       });
75 }
76 
77 }  // namespace
78 
79 /* static */ StatusOr<std::vector<se::Platform*>>
GetSupportedPlatforms()80 PlatformUtil::GetSupportedPlatforms() {
81   // Gather all platforms which have an XLA compiler.
82   return xla::GetSupportedPlatforms();
83 }
84 
GetDefaultPlatform()85 /* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
86   TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
87 
88   se::Platform* platform = nullptr;
89   if (platforms.empty()) {
90     return NotFound("no platforms found");
91   } else if (platforms.size() == 1) {
92     platform = platforms[0];
93   } else if (platforms.size() == 2) {
94     for (int i = 0; i < 2; i++) {
95       if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter &&
96           absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) {
97         platform = platforms[1 - i];
98         break;
99       }
100     }
101   }
102   if (platform != nullptr) {
103     return platform;
104   }
105 
106   // Multiple platforms present and we can't pick a reasonable default.
107   std::string platforms_string = absl::StrJoin(
108       platforms, ", ",
109       [](std::string* out, const se::Platform* p) { out->append(p->Name()); });
110   return InvalidArgument(
111       "must specify platform because more than one platform (except for the "
112       "interpreter platform) found: %s.",
113       platforms_string);
114 }
115 
GetPlatform(const std::string & platform_name)116 /*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
117     const std::string& platform_name) {
118   TF_ASSIGN_OR_RETURN(se::Platform * platform,
119                       se::MultiPlatformManager::PlatformWithName(
120                           CanonicalPlatformName(platform_name)));
121   TF_RETURN_IF_ERROR(Compiler::GetForPlatform(platform).status());
122   return platform;
123 }
124 
125 // Returns whether the device underlying the given StreamExecutor is supported
126 // by XLA.
IsDeviceSupported(se::StreamExecutor * executor)127 static bool IsDeviceSupported(se::StreamExecutor* executor) {
128   const auto& description = executor->GetDeviceDescription();
129   if (executor->platform()->id() == se::cuda::kCudaPlatformId) {
130     // CUDA devices must have a minimum compute capability.
131     se::CudaComputeCapability cc = description.cuda_compute_capability();
132     if (!cc.IsAtLeast(kMinCudaComputeCapabilityMajor,
133                       kMinCudaComputeCapabilityMinor)) {
134       LOG(INFO) << "StreamExecutor cuda device (" << executor->device_ordinal()
135                 << ") is of "
136                 << "insufficient compute capability: "
137                 << kMinCudaComputeCapabilityMajor << "."
138                 << kMinCudaComputeCapabilityMinor << " required, "
139                 << "device is " << cc.ToString();
140       return false;
141     }
142   } else if (executor->platform()->id() == se::rocm::kROCmPlatformId) {
143     auto rocm_compute_capability = description.rocm_compute_capability();
144     if (!rocm_compute_capability.is_supported_gfx_version()) {
145       LOG(INFO) << "StreamExecutor ROCM device (" << executor->device_ordinal()
146                 << ") is of unsupported "
147                 << "AMDGPU version : " << rocm_compute_capability.gfx_version()
148                 << ". The supported AMDGPU versions are "
149                 << rocm_compute_capability.supported_gfx_versions_str() << ".";
150       return false;
151     }
152   }
153   return true;
154 }
155 
156 /* static */ StatusOr<std::vector<se::StreamExecutor*>>
GetStreamExecutors(se::Platform * platform,const std::optional<std::set<int>> & allowed_devices)157 PlatformUtil::GetStreamExecutors(
158     se::Platform* platform,
159     const std::optional<std::set<int>>& allowed_devices) {
160   int device_count = platform->VisibleDeviceCount();
161   if (device_count <= 0) {
162     return NotFound("no %s devices found", platform->Name());
163   }
164   if (platform->id() == se::host::kHostPlatformId) {
165     // On host "devices", StreamExecutor exports a device for each hardware
166     // thread. Because we parallelize a single computation across threads, it
167     // doesn't make sense to expose these as separate devices, so by default we
168     // fix the number of devices to one.  However we do let the user override
169     // this behavior to help run tests on the host that run models in parallel
170     // across multiple devices.
171     device_count =
172         GetDebugOptionsFromFlags().xla_force_host_platform_device_count();
173   }
174   std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
175   VLOG(1) << "Initializing devices";
176   {
177     tensorflow::thread::ThreadPool thread_pool(
178         tensorflow::Env::Default(), "device_initialization", device_count);
179     auto create_fn = [](se::Platform* platform,
180                         std::vector<se::StreamExecutor*>& stream_executors,
181                         int device_ordinal, int count) {
182       VLOG(1) << "Started device init " << device_ordinal;
183       auto executor_status = platform->ExecutorForDevice(device_ordinal);
184       if (executor_status.ok()) {
185         se::StreamExecutor* executor = executor_status.ValueOrDie();
186         if (IsDeviceSupported(executor)) {
187           stream_executors[count] = executor;
188         }
189       } else {
190         LOG(WARNING) << "unable to create StreamExecutor for "
191                      << platform->Name() << ":" << device_ordinal << ": "
192                      << executor_status.status().error_message();
193       }
194       VLOG(1) << "Finished device init " << device_ordinal;
195     };
196     // Once a stream executor is instantiated it will cause allocations on
197     // the device, for example for GPUs cuda context, cudnn handles etc. will
198     // be constructed. By constructing stream executors only on the
199     // allowed_devices, we don't make any allocations on other devices.
200     // This helps in multi-process executions on the same host like horovod or
201     // shared hosts.
202     if (allowed_devices) {
203       int count = 0;
204       for (const auto& i : *allowed_devices) {
205         if (count >= device_count) {
206           break;
207         }
208         thread_pool.Schedule(
209             [platform, &stream_executors, i, count, &create_fn]() {
210               create_fn(platform, stream_executors, i, count);
211             });
212         count++;
213       }
214     } else {
215       for (int i = 0; i < device_count; ++i) {
216         thread_pool.Schedule([platform, &stream_executors, i, &create_fn]() {
217           create_fn(platform, stream_executors, i, i);
218         });
219       }
220     }
221     // Block here in thread_pool destructor until all devices are initialized.
222   }
223   VLOG(1) << "Device initialization complete";
224 
225   std::vector<se::StreamExecutor*> out;
226   for (se::StreamExecutor* executor : stream_executors) {
227     if (executor != nullptr) {
228       out.push_back(executor);
229     }
230   }
231   if (out.empty()) {
232     return InternalError("no supported devices found for platform %s",
233                          platform->Name());
234   }
235   return out;
236 }
237 
238 }  // namespace xla
239