// Copyright 2021 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TENSORFLOW_CORE_TFRT_GPU_GPU_SHARED_CONTEXT_H_
#define TENSORFLOW_CORE_TFRT_GPU_GPU_SHARED_CONTEXT_H_

#include <cstdint>
#include <functional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tfrt/support/error_util.h"  // from @tf_runtime

namespace tfrt {
namespace gpu {

// Key for naming a particular XCCL clique. This is just a set of unique
// device IDs (i.e. GPU IDs). The device IDs must be global within a
// collective.
using XcclCliqueKey = std::vector<int64_t>;
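// For example, a clique spanning global device IDs 0 and 1 would be keyed by
// the vector {0, 1} (illustrative values only).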

// Callback that returns a ncclUniqueId encoded as a string for a group of
// communicating GPU devices.
using XcclUniqueIdCallback =
    std::function<Expected<std::string>(const XcclCliqueKey&)>;

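// For illustration only: a callback might look up or exchange an ID that was
// generated by rank 0 and distributed out of band. `ExchangeUniqueId` is a
// hypothetical helper, not part of this API:
//
//   XcclUniqueIdCallback callback =
//       [](const XcclCliqueKey& key) -> Expected<std::string> {
//         return ExchangeUniqueId(key);  // hypothetical distribution step
//       };
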
// TODO(hanbinyoon): Rename this class appropriately.
// This class contains stateful resources needed to compile and execute
// programs in the XLA GPU integration environment.
class GpuSharedContext {
 public:
  // For BefThunk integration, this is the device ordinal.
  typedef int LocalDeviceIdentifier;

  explicit GpuSharedContext(
      int64_t run_id,
      absl::flat_hash_map<LocalDeviceIdentifier, int> local_ids_to_rank,
      std::vector<int64_t> gpu_global_device_ids,
      XcclUniqueIdCallback xccl_unique_id_callback,
      const std::string* compiled_code);
  // Accessors
  int64_t run_id() const { return run_id_; }
  const absl::flat_hash_map<LocalDeviceIdentifier, int>& local_ids_to_rank()
      const {
    return local_ids_to_rank_;
  }
  const std::vector<int64_t>& gpu_global_device_ids() const {
    return gpu_global_device_ids_;
  }
  const XcclUniqueIdCallback& xccl_unique_id_callback() const {
    return xccl_unique_id_callback_;
  }
  const std::string* compiled_code() const { return compiled_code_; }

 private:
  int64_t run_id_;
  const absl::flat_hash_map<LocalDeviceIdentifier, int> local_ids_to_rank_;
  const std::vector<int64_t> gpu_global_device_ids_;
  const XcclUniqueIdCallback xccl_unique_id_callback_;

  // The compiled code is PTX under CUDA and an unused empty string under
  // ROCm.
  const std::string* compiled_code_;
};

}  // namespace gpu
}  // namespace tfrt

#endif  // TENSORFLOW_CORE_TFRT_GPU_GPU_SHARED_CONTEXT_H_