// Copyright 2021 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TENSORFLOW_CORE_TFRT_GPU_GPU_SHARED_CONTEXT_H_
#define TENSORFLOW_CORE_TFRT_GPU_GPU_SHARED_CONTEXT_H_

#include <cstdint>
#include <functional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tfrt/support/error_util.h"  // from @tf_runtime

namespace tfrt {
namespace gpu {

// Key for naming a particular XCCL clique. This is just a set of unique
// device IDs (i.e. GPU IDs). The device IDs must be global within a
// collective.
using XcclCliqueKey = std::vector<int64_t>;

// Callback that returns a ncclUniqueId encoded as a string for a group of
// communicating GPU devices.
using XcclUniqueIdCallback =
    std::function<Expected<std::string>(const XcclCliqueKey&)>;

// TODO(hanbinyoon): Rename this class appropriately.
// This class contains stateful resources needed to compile and execute
// programs in the XLA GPU integration environment.
class GpuSharedContext {
 public:
  // For BefThunk integration, this is the device ordinal.
  typedef int LocalDeviceIdentifier;

  explicit GpuSharedContext(
      int64_t run_id,
      absl::flat_hash_map<LocalDeviceIdentifier, int> local_ids_to_rank,
      std::vector<int64_t> gpu_global_device_ids,
      XcclUniqueIdCallback xccl_unique_id_callback,
      const std::string* compiled_code);

  // Accessors
  int64_t run_id() const { return run_id_; }
  const absl::flat_hash_map<LocalDeviceIdentifier, int>& local_ids_to_rank()
      const {
    return local_ids_to_rank_;
  }
  const std::vector<int64_t>& gpu_global_device_ids() const {
    return gpu_global_device_ids_;
  }
  const XcclUniqueIdCallback& xccl_unique_id_callback() const {
    return xccl_unique_id_callback_;
  }
  const std::string* compiled_code() const { return compiled_code_; }

 private:
  int64_t run_id_;
  const absl::flat_hash_map<LocalDeviceIdentifier, int> local_ids_to_rank_;
  // Stored by value: the constructor receives the vector by value, so holding
  // a reference to that argument would dangle.
  const std::vector<int64_t> gpu_global_device_ids_;
  const XcclUniqueIdCallback xccl_unique_id_callback_;

  // The compiled code is PTX for CUDA, and an unused empty string for ROCm.
  const std::string* compiled_code_;
};

}  // namespace gpu
}  // namespace tfrt

#endif  // TENSORFLOW_CORE_TFRT_GPU_GPU_SHARED_CONTEXT_H_
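
// Example usage (a minimal sketch, not part of the original header; the
// callback body, device IDs, ranks, and PTX string below are hypothetical
// and only illustrate the expected shapes):
//
//   // A stub XcclUniqueIdCallback. A real implementation would obtain an
//   // id for the clique (e.g. from ncclGetUniqueId() or a distributed
//   // id-exchange service) and return it encoded as a string.
//   tfrt::gpu::XcclUniqueIdCallback callback =
//       [](const tfrt::gpu::XcclCliqueKey& key)
//           -> tfrt::Expected<std::string> {
//     return std::string("fake-unique-id");  // placeholder encoding
//   };
//
//   // Two local devices (ordinals 0 and 1) mapped to ranks 0 and 1.
//   absl::flat_hash_map<int, int> local_ids_to_rank = {{0, 0}, {1, 1}};
//
//   // PTX for CUDA (an unused empty string for ROCm). Must outlive the
//   // context, since GpuSharedContext stores only the pointer.
//   const std::string compiled_code = "...";
//
//   tfrt::gpu::GpuSharedContext context(
//       /*run_id=*/42, std::move(local_ids_to_rank),
//       /*gpu_global_device_ids=*/{0, 1}, std::move(callback),
//       &compiled_code);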