1 #include <torch/csrc/distributed/c10d/DMAConnectivity.hpp>
2
3 namespace {
4
get_detector_key(c10::DeviceType device_type,std::string connection_type)5 std::string get_detector_key(
6 c10::DeviceType device_type,
7 std::string connection_type) {
8 std::ostringstream oss;
9 oss << device_type << "/" << connection_type;
10 return oss.str();
11 }
12
13 class DetectorMap {
14 public:
get()15 static DetectorMap& get() {
16 static DetectorMap instance;
17 return instance;
18 }
19
register_detector(c10::DeviceType device_type,const std::string & connection_type,c10::intrusive_ptr<c10d::DMAConnectivityDetector> detector)20 void register_detector(
21 c10::DeviceType device_type,
22 const std::string& connection_type,
23 c10::intrusive_ptr<c10d::DMAConnectivityDetector> detector) {
24 auto key = get_detector_key(device_type, connection_type);
25 detector_map_[key] = std::move(detector);
26 }
27
detect(c10::DeviceType device_type,const std::string & connection_type)28 c10::intrusive_ptr<c10d::DMAConnectivity> detect(
29 c10::DeviceType device_type,
30 const std::string& connection_type) {
31 auto key = get_detector_key(device_type, connection_type);
32 {
33 auto it = cached_.find(key);
34 if (it != cached_.end()) {
35 return it->second;
36 }
37 }
38
39 auto it = detector_map_.find(key);
40 TORCH_CHECK(
41 it != detector_map_.end(),
42 "DMA connectivity detector for ",
43 device_type,
44 " over ",
45 connection_type,
46 " is not available");
47 auto detector = it->second;
48 auto connectivity = detector->detect();
49 cached_[key] = connectivity;
50 return connectivity;
51 }
52
53 private:
54 DetectorMap() = default;
55 DetectorMap(const DetectorMap&) = delete;
56 DetectorMap& operator=(const DetectorMap&) = delete;
57
58 std::unordered_map<
59 std::string,
60 c10::intrusive_ptr<c10d::DMAConnectivityDetector>>
61 detector_map_;
62
63 std::unordered_map<std::string, c10::intrusive_ptr<c10d::DMAConnectivity>>
64 cached_;
65 };
66
67 }; // namespace
68
69 namespace c10d {
70
DMAConnectivity(c10::DeviceType device_type,std::string connection_type,std::vector<std::vector<int>> matrix)71 DMAConnectivity::DMAConnectivity(
72 c10::DeviceType device_type,
73 std::string connection_type,
74 std::vector<std::vector<int>> matrix)
75 : device_type(device_type),
76 connection_type(connection_type),
77 matrix(std::move(matrix)) {}
78
register_dma_connectivity_detector(c10::DeviceType device_type,const std::string & connection_type,c10::intrusive_ptr<DMAConnectivityDetector> detector)79 void register_dma_connectivity_detector(
80 c10::DeviceType device_type,
81 const std::string& connection_type,
82 c10::intrusive_ptr<DMAConnectivityDetector> detector) {
83 return DetectorMap::get().register_detector(
84 device_type, connection_type, std::move(detector));
85 }
86
detect_dma_connectivity(c10::DeviceType device_type,const std::string & connection_type)87 c10::intrusive_ptr<DMAConnectivity> detect_dma_connectivity(
88 c10::DeviceType device_type,
89 const std::string& connection_type) {
90 return DetectorMap::get().detect(device_type, connection_type);
91 }
92
93 } // namespace c10d
94