#include <c10/core/Device.h>
#include <c10/util/Exception.h>

#include <algorithm>
#include <array>
#include <cctype>
#include <exception>
#include <limits>
#include <string>
#include <vector>
10
11 namespace c10 {
12 namespace {
parse_type(const std::string & device_string)13 DeviceType parse_type(const std::string& device_string) {
14 static const std::array<
15 std::pair<const char*, DeviceType>,
16 static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
17 types = {{
18 {"cpu", DeviceType::CPU},
19 {"cuda", DeviceType::CUDA},
20 {"ipu", DeviceType::IPU},
21 {"xpu", DeviceType::XPU},
22 {"mkldnn", DeviceType::MKLDNN},
23 {"opengl", DeviceType::OPENGL},
24 {"opencl", DeviceType::OPENCL},
25 {"ideep", DeviceType::IDEEP},
26 {"hip", DeviceType::HIP},
27 {"ve", DeviceType::VE},
28 {"fpga", DeviceType::FPGA},
29 {"maia", DeviceType::MAIA},
30 {"xla", DeviceType::XLA},
31 {"lazy", DeviceType::Lazy},
32 {"vulkan", DeviceType::Vulkan},
33 {"mps", DeviceType::MPS},
34 {"meta", DeviceType::Meta},
35 {"hpu", DeviceType::HPU},
36 {"mtia", DeviceType::MTIA},
37 {"privateuseone", DeviceType::PrivateUse1},
38 }};
39 auto device = std::find_if(
40 types.begin(),
41 types.end(),
42 [&device_string](const std::pair<const char*, DeviceType>& p) {
43 return p.first && p.first == device_string;
44 });
45 if (device != types.end()) {
46 return device->second;
47 }
48 if (device_string == get_privateuse1_backend()) {
49 return DeviceType::PrivateUse1;
50 }
51 std::vector<const char*> device_names;
52 for (const auto& it : types) {
53 if (it.first) {
54 device_names.push_back(it.first);
55 }
56 }
57 TORCH_CHECK(
58 false,
59 "Expected one of ",
60 c10::Join(", ", device_names),
61 " device type at start of device string: ",
62 device_string);
63 }
// States of the hand-written matcher used by Device::Device(const std::string&).
// Scoped (enum class) so that generic enumerator names such as START and ERROR
// do not leak into the enclosing namespace; all uses are already qualified.
enum class DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR };
65
66 } // namespace
67
Device(const std::string & device_string)68 Device::Device(const std::string& device_string) : Device(Type::CPU) {
69 TORCH_CHECK(!device_string.empty(), "Device string must not be empty");
70
71 std::string device_name, device_index_str;
72 DeviceStringParsingState pstate = DeviceStringParsingState::START;
73
74 // The code below tries to match the string in the variable
75 // device_string against the regular expression:
76 // ([a-zA-Z_]+)(?::([1-9]\\d*|0))?
77 for (size_t i = 0;
78 pstate != DeviceStringParsingState::ERROR && i < device_string.size();
79 ++i) {
80 const char ch = device_string.at(i);
81 switch (pstate) {
82 case DeviceStringParsingState::START:
83 if (ch != ':') {
84 if (isalpha(ch) || ch == '_') {
85 device_name.push_back(ch);
86 } else {
87 pstate = DeviceStringParsingState::ERROR;
88 }
89 } else {
90 pstate = DeviceStringParsingState::INDEX_START;
91 }
92 break;
93
94 case DeviceStringParsingState::INDEX_START:
95 if (isdigit(ch)) {
96 device_index_str.push_back(ch);
97 pstate = DeviceStringParsingState::INDEX_REST;
98 } else {
99 pstate = DeviceStringParsingState::ERROR;
100 }
101 break;
102
103 case DeviceStringParsingState::INDEX_REST:
104 if (device_index_str.at(0) == '0') {
105 pstate = DeviceStringParsingState::ERROR;
106 break;
107 }
108 if (isdigit(ch)) {
109 device_index_str.push_back(ch);
110 } else {
111 pstate = DeviceStringParsingState::ERROR;
112 }
113 break;
114
115 case DeviceStringParsingState::ERROR:
116 // Execution won't reach here.
117 break;
118 }
119 }
120
121 const bool has_error = device_name.empty() ||
122 pstate == DeviceStringParsingState::ERROR ||
123 (pstate == DeviceStringParsingState::INDEX_START &&
124 device_index_str.empty());
125
126 TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'");
127
128 try {
129 if (!device_index_str.empty()) {
130 index_ = static_cast<c10::DeviceIndex>(std::stoi(device_index_str));
131 }
132 } catch (const std::exception&) {
133 TORCH_CHECK(
134 false,
135 "Could not parse device index '",
136 device_index_str,
137 "' in device string '",
138 device_string,
139 "'");
140 }
141 type_ = parse_type(device_name);
142 validate();
143 }
144
str() const145 std::string Device::str() const {
146 std::string str = DeviceTypeName(type(), /* lower case */ true);
147 if (has_index()) {
148 str.push_back(':');
149 str.append(std::to_string(index()));
150 }
151 return str;
152 }
153
operator <<(std::ostream & stream,const Device & device)154 std::ostream& operator<<(std::ostream& stream, const Device& device) {
155 stream << device.str();
156 return stream;
157 }
158
159 } // namespace c10
160