xref: /aosp_15_r20/external/pytorch/c10/core/Device.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/core/Device.h>
#include <c10/util/Exception.h>

#include <algorithm>
#include <array>
#include <cctype>
#include <exception>
#include <limits>
#include <string>
#include <vector>
10 
11 namespace c10 {
12 namespace {
parse_type(const std::string & device_string)13 DeviceType parse_type(const std::string& device_string) {
14   static const std::array<
15       std::pair<const char*, DeviceType>,
16       static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
17       types = {{
18           {"cpu", DeviceType::CPU},
19           {"cuda", DeviceType::CUDA},
20           {"ipu", DeviceType::IPU},
21           {"xpu", DeviceType::XPU},
22           {"mkldnn", DeviceType::MKLDNN},
23           {"opengl", DeviceType::OPENGL},
24           {"opencl", DeviceType::OPENCL},
25           {"ideep", DeviceType::IDEEP},
26           {"hip", DeviceType::HIP},
27           {"ve", DeviceType::VE},
28           {"fpga", DeviceType::FPGA},
29           {"maia", DeviceType::MAIA},
30           {"xla", DeviceType::XLA},
31           {"lazy", DeviceType::Lazy},
32           {"vulkan", DeviceType::Vulkan},
33           {"mps", DeviceType::MPS},
34           {"meta", DeviceType::Meta},
35           {"hpu", DeviceType::HPU},
36           {"mtia", DeviceType::MTIA},
37           {"privateuseone", DeviceType::PrivateUse1},
38       }};
39   auto device = std::find_if(
40       types.begin(),
41       types.end(),
42       [&device_string](const std::pair<const char*, DeviceType>& p) {
43         return p.first && p.first == device_string;
44       });
45   if (device != types.end()) {
46     return device->second;
47   }
48   if (device_string == get_privateuse1_backend()) {
49     return DeviceType::PrivateUse1;
50   }
51   std::vector<const char*> device_names;
52   for (const auto& it : types) {
53     if (it.first) {
54       device_names.push_back(it.first);
55     }
56   }
57   TORCH_CHECK(
58       false,
59       "Expected one of ",
60       c10::Join(", ", device_names),
61       " device type at start of device string: ",
62       device_string);
63 }
64 enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR };
65 
66 } // namespace
67 
Device(const std::string & device_string)68 Device::Device(const std::string& device_string) : Device(Type::CPU) {
69   TORCH_CHECK(!device_string.empty(), "Device string must not be empty");
70 
71   std::string device_name, device_index_str;
72   DeviceStringParsingState pstate = DeviceStringParsingState::START;
73 
74   // The code below tries to match the string in the variable
75   // device_string against the regular expression:
76   // ([a-zA-Z_]+)(?::([1-9]\\d*|0))?
77   for (size_t i = 0;
78        pstate != DeviceStringParsingState::ERROR && i < device_string.size();
79        ++i) {
80     const char ch = device_string.at(i);
81     switch (pstate) {
82       case DeviceStringParsingState::START:
83         if (ch != ':') {
84           if (isalpha(ch) || ch == '_') {
85             device_name.push_back(ch);
86           } else {
87             pstate = DeviceStringParsingState::ERROR;
88           }
89         } else {
90           pstate = DeviceStringParsingState::INDEX_START;
91         }
92         break;
93 
94       case DeviceStringParsingState::INDEX_START:
95         if (isdigit(ch)) {
96           device_index_str.push_back(ch);
97           pstate = DeviceStringParsingState::INDEX_REST;
98         } else {
99           pstate = DeviceStringParsingState::ERROR;
100         }
101         break;
102 
103       case DeviceStringParsingState::INDEX_REST:
104         if (device_index_str.at(0) == '0') {
105           pstate = DeviceStringParsingState::ERROR;
106           break;
107         }
108         if (isdigit(ch)) {
109           device_index_str.push_back(ch);
110         } else {
111           pstate = DeviceStringParsingState::ERROR;
112         }
113         break;
114 
115       case DeviceStringParsingState::ERROR:
116         // Execution won't reach here.
117         break;
118     }
119   }
120 
121   const bool has_error = device_name.empty() ||
122       pstate == DeviceStringParsingState::ERROR ||
123       (pstate == DeviceStringParsingState::INDEX_START &&
124        device_index_str.empty());
125 
126   TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'");
127 
128   try {
129     if (!device_index_str.empty()) {
130       index_ = static_cast<c10::DeviceIndex>(std::stoi(device_index_str));
131     }
132   } catch (const std::exception&) {
133     TORCH_CHECK(
134         false,
135         "Could not parse device index '",
136         device_index_str,
137         "' in device string '",
138         device_string,
139         "'");
140   }
141   type_ = parse_type(device_name);
142   validate();
143 }
144 
str() const145 std::string Device::str() const {
146   std::string str = DeviceTypeName(type(), /* lower case */ true);
147   if (has_index()) {
148     str.push_back(':');
149     str.append(std::to_string(index()));
150   }
151   return str;
152 }
153 
operator <<(std::ostream & stream,const Device & device)154 std::ostream& operator<<(std::ostream& stream, const Device& device) {
155   stream << device.str();
156   return stream;
157 }
158 
159 } // namespace c10
160