xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/DMAConnectivity.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/c10d/DMAConnectivity.hpp>
2 
3 namespace {
4 
get_detector_key(c10::DeviceType device_type,std::string connection_type)5 std::string get_detector_key(
6     c10::DeviceType device_type,
7     std::string connection_type) {
8   std::ostringstream oss;
9   oss << device_type << "/" << connection_type;
10   return oss.str();
11 }
12 
13 class DetectorMap {
14  public:
get()15   static DetectorMap& get() {
16     static DetectorMap instance;
17     return instance;
18   }
19 
register_detector(c10::DeviceType device_type,const std::string & connection_type,c10::intrusive_ptr<c10d::DMAConnectivityDetector> detector)20   void register_detector(
21       c10::DeviceType device_type,
22       const std::string& connection_type,
23       c10::intrusive_ptr<c10d::DMAConnectivityDetector> detector) {
24     auto key = get_detector_key(device_type, connection_type);
25     detector_map_[key] = std::move(detector);
26   }
27 
detect(c10::DeviceType device_type,const std::string & connection_type)28   c10::intrusive_ptr<c10d::DMAConnectivity> detect(
29       c10::DeviceType device_type,
30       const std::string& connection_type) {
31     auto key = get_detector_key(device_type, connection_type);
32     {
33       auto it = cached_.find(key);
34       if (it != cached_.end()) {
35         return it->second;
36       }
37     }
38 
39     auto it = detector_map_.find(key);
40     TORCH_CHECK(
41         it != detector_map_.end(),
42         "DMA connectivity detector for ",
43         device_type,
44         " over ",
45         connection_type,
46         " is not available");
47     auto detector = it->second;
48     auto connectivity = detector->detect();
49     cached_[key] = connectivity;
50     return connectivity;
51   }
52 
53  private:
54   DetectorMap() = default;
55   DetectorMap(const DetectorMap&) = delete;
56   DetectorMap& operator=(const DetectorMap&) = delete;
57 
58   std::unordered_map<
59       std::string,
60       c10::intrusive_ptr<c10d::DMAConnectivityDetector>>
61       detector_map_;
62 
63   std::unordered_map<std::string, c10::intrusive_ptr<c10d::DMAConnectivity>>
64       cached_;
65 };
66 
67 }; // namespace
68 
69 namespace c10d {
70 
DMAConnectivity(c10::DeviceType device_type,std::string connection_type,std::vector<std::vector<int>> matrix)71 DMAConnectivity::DMAConnectivity(
72     c10::DeviceType device_type,
73     std::string connection_type,
74     std::vector<std::vector<int>> matrix)
75     : device_type(device_type),
76       connection_type(connection_type),
77       matrix(std::move(matrix)) {}
78 
register_dma_connectivity_detector(c10::DeviceType device_type,const std::string & connection_type,c10::intrusive_ptr<DMAConnectivityDetector> detector)79 void register_dma_connectivity_detector(
80     c10::DeviceType device_type,
81     const std::string& connection_type,
82     c10::intrusive_ptr<DMAConnectivityDetector> detector) {
83   return DetectorMap::get().register_detector(
84       device_type, connection_type, std::move(detector));
85 }
86 
detect_dma_connectivity(c10::DeviceType device_type,const std::string & connection_type)87 c10::intrusive_ptr<DMAConnectivity> detect_dma_connectivity(
88     c10::DeviceType device_type,
89     const std::string& connection_type) {
90   return DetectorMap::get().detect(device_type, connection_type);
91 }
92 
93 } // namespace c10d
94