1 #include <ATen/core/interned_strings.h>
2 #include <ATen/core/jit_type.h>
3 #include <c10/core/Device.h>
4 #include <c10/util/ArrayRef.h>
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/jit_log.h>
7 #include <torch/csrc/jit/passes/device_type_analysis.h>
8 #include <torch/csrc/jit/passes/shape_analysis.h>
9 #include <memory>
10 #include <optional>
11 #include <utility>
12
13 namespace torch::jit {
14
15 namespace {
16
17 using Tensor = at::Tensor;
18 using Device = at::Device;
19
20 using PropRule = std::function<bool(Node*)>;
21 /*
22 A Propagation Rule takes the Node, and
23 applies the relevant properties to the Tensor outputs
24 of the Node (based on the rule itself)
25
26 Returns: Bool indicating if anything was changed
27 */
28
setDeviceType(Value * value,std::optional<Device> device)29 bool setDeviceType(Value* value, std::optional<Device> device) {
30 auto tensor_type = value->type()->expect<TensorType>();
31 bool changed = tensor_type->device() != device;
32 if (changed) {
33 value->setType(tensor_type->withDevice(device));
34 }
35 return changed;
36 }
37
setReturnsToDevice(Node * n,std::optional<Device> device)38 bool setReturnsToDevice(Node* n, std::optional<Device> device) {
39 bool changed = false;
40 for (Value* out : n->outputs()) {
41 auto tensor_type = out->type()->cast<TensorType>();
42 if (!tensor_type) {
43 continue;
44 }
45 changed |= setDeviceType(out, device);
46 }
47 return changed;
48 }
49
setReturnstoDeviceRule(DeviceType deviceType)50 PropRule setReturnstoDeviceRule(DeviceType deviceType) {
51 Device device = Device(deviceType);
52 return [=](Node* n) { return setReturnsToDevice(n, device); };
53 }
54
returnFirstArgDeviceRule(Node * n)55 bool returnFirstArgDeviceRule(Node* n) {
56 // Custom Rule for when multiple args can have mismatched device types
57 auto tensor_type = n->inputs()[0]->type()->cast<TensorType>();
58 TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type");
59 return setReturnsToDevice(n, tensor_type->device());
60 }
61
returnSecondArgDeviceRule(Node * n)62 bool returnSecondArgDeviceRule(Node* n) {
63 // Custom Rule for when multiple args can have mismatched device types
64 auto tensor_type = n->inputs()[1]->type()->cast<TensorType>();
65 TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type");
66 return setReturnsToDevice(n, tensor_type->device());
67 }
68
isZerodimCPUTensor(const std::shared_ptr<TensorType> & tensor_type)69 bool isZerodimCPUTensor(const std::shared_ptr<TensorType>& tensor_type) {
70 // CPU devices on zerodim tensors are the only device that can be
71 // overwritten by another device. Therefore, to be conservative
72 // assume that it is not a zerodim cpu tensor if something is not known.
73 bool is_zerodim = tensor_type->symbolic_sizes().rank().value_or(-1) == 0;
74 bool is_cpu = tensor_type->device() && tensor_type->device()->is_cpu();
75 return is_zerodim && is_cpu;
76 }
77
// Fallback propagation when no explicit Device argument is present.
// Propagates a device to the outputs only if every tensor input agrees
// on it; zerodim CPU tensors are "weak" and may be overridden by any
// other device (see isZerodimCPUTensor). On any disagreement the
// outputs' devices are cleared (set to nullopt).
// Returns: true iff any output's device annotation changed.
bool propWithNoDevice(Node* n) {
  // Propagate if we can verify that all input devices match,
  // except CPU zerodim, which any other type can overwrite
  size_t input_num = 0;

  // Advance to the first tensor-typed input, skipping scalars/lists/etc.
  for (; input_num < n->inputs().size(); input_num++) {
    if (n->inputs()[input_num]->type()->cast<TensorType>()) {
      break;
    }
  }
  if (input_num == n->inputs().size()) {
    // No tensor found
    return setReturnsToDevice(n, std::nullopt);
  }

  // Seed the candidate device from the first tensor input.
  auto tensor_type = n->inputs()[input_num]->type()->expect<TensorType>();
  bool only_seen_cpu_zerodim = isZerodimCPUTensor(tensor_type);
  std::optional<Device> device = tensor_type->device();

  // Now see if all inputs have a consistent device type
  for (input_num++; input_num < n->inputs().size(); input_num++) {
    auto tensor_type = n->inputs()[input_num]->type()->cast<TensorType>();
    // Non-tensor inputs and weak (zerodim CPU) tensors never veto.
    if (!tensor_type || isZerodimCPUTensor(tensor_type)) {
      continue;
    }

    if (device != tensor_type->device()) {
      if (only_seen_cpu_zerodim) {
        // First "strong" device seen overrides the weak zerodim-CPU one.
        device = tensor_type->device();
        only_seen_cpu_zerodim = false;
      } else {
        // Bail on the type not match case
        return setReturnsToDevice(n, std::nullopt);
      }
    }
  }
  return setReturnsToDevice(n, device);
}
116
defaultDeviceProp(Node * n)117 bool defaultDeviceProp(Node* n) {
118 // Detecting if the op has a device object argument
119 // as there is implicit string conversion to device
120 auto schema = n->maybeSchema();
121 if (!schema) {
122 return false;
123 }
124 auto arguments = schema->arguments();
125 for (size_t i = 0; i < arguments.size(); i++) {
126 Argument& argument = arguments[i];
127 if (DeviceObjType::get()->isSubtypeOf(argument.type())) {
128 // Optional args are filled in by torchscript with default val
129 auto input_val = toIValue(n->inputs().at(i));
130 if (!input_val.has_value()) {
131 // Can't propagate if there is a dynamic device type
132 return false;
133 }
134 if (input_val->isNone()) {
135 continue;
136 }
137 if (!input_val->isDevice()) {
138 // Bail on union types
139 return false;
140 }
141 TORCH_INTERNAL_ASSERT(input_val->isDevice())
142 Device device = input_val->toDevice();
143 return setReturnsToDevice(n, device);
144 }
145 }
146 return propWithNoDevice(n);
147 }
148
// Graph pass that walks every node and propagates tensor device types,
// using the shared PropertyPropBase traversal for blocks/If/Loop.
struct DeviceTypePropagationPass : public PropertyPropBase {
  explicit DeviceTypePropagationPass(std::shared_ptr<Graph> graph)
      : PropertyPropBase(std::move(graph)) {
    // Build the custom-rule registry once (it is a process-wide static).
    buildRuleRegistry();
  }

  // returns true if at least one node has its scalar type set on a tensor node
  bool run() {
    propagateBlock(graph_->block(), false);
    return changed_;
  }

 private:
  // Per-node dispatch: control-flow nodes route to the base-class
  // helpers, aten ops to processAtenOps; everything else is skipped.
  void propagateNode(Node* n, bool _ = true) override {
    GRAPH_DEBUG("processNode");
    switch (n->kind()) {
      case prim::If:
        return processIf(n);
      case prim::Loop:
        return processLoop(n);
      case prim::CallMethod:
      case prim::CallFunction:
        return; // Not handled for now
      default:
        break;
    }

    bool has_tensor_output =
        std::any_of(n->outputs().begin(), n->outputs().end(), [](Value* v) {
          return (bool)v->type()->cast<TensorType>();
        });

    if (!has_tensor_output) {
      // if output contains no tensor, nothing to propagate
      return;
    }

    switch (n->kind()) {
      case prim::Constant:
        // This is already been propagated by something else
      case prim::ListConstruct:
      case prim::ListUnpack:
        return; // Not handled for now
      default:
        if (n->kind().is_aten()) {
          return processAtenOps(n);
        } else {
          return; // Not handled for now
        }
    }
  }

  // Applies a custom registry rule when one matches the node's operator;
  // otherwise falls back to defaultDeviceProp. Accumulates into changed_.
  void processAtenOps(Node* n) {
    GRAPH_DEBUG("processAtenOps");
    GRAPH_DEBUG("case = ", n->kind(), " ", *n);
    // Custom Rule Matching
    auto op = n->maybeOperator();
    if (!op) {
      return;
    }
    auto prop_fn = device_prop_registry_->find(*op);
    if (prop_fn) {
      PropRule rule = *prop_fn;
      changed_ |= rule(n);
      return;
    }
    changed_ |= defaultDeviceProp(n);
  }

  void buildRuleRegistry() {
    // building a registry for all of the custom Device Type rules
    if (device_prop_registry_)
      return;

    static OperatorMap<PropRule> temp_registry{
        {"aten::cpu(Tensor self) -> Tensor",
         setReturnstoDeviceRule(DeviceType::CPU)},
        {"aten::cuda(Tensor self) -> Tensor",
         setReturnstoDeviceRule(DeviceType::CUDA)},
        {"aten::to_mkldnn(Tensor self, ScalarType? dtype) -> Tensor",
         setReturnstoDeviceRule(DeviceType::MKLDNN)},
        {"aten::reshape_as(Tensor self, Tensor other) -> Tensor",
         returnFirstArgDeviceRule},
        {"aten::view_as(Tensor self, Tensor other) -> Tensor",
         returnFirstArgDeviceRule},
        {"aten::expand_as(Tensor self, Tensor other) -> Tensor",
         returnFirstArgDeviceRule},
        {"aten::type_as(Tensor self, Tensor other) -> Tensor",
         returnSecondArgDeviceRule},
    };
    device_prop_registry_ =
        std::make_unique<OperatorMap<PropRule>>(std::move(temp_registry));
  }

  // Shared across all pass instances; initialized on first construction.
  static std::unique_ptr<OperatorMap<PropRule>> device_prop_registry_;
  bool changed_ = false;
};
246
// Out-of-line definition of the static rule registry; populated lazily
// by buildRuleRegistry() on first pass construction.
std::unique_ptr<OperatorMap<PropRule>>
    DeviceTypePropagationPass::device_prop_registry_ = nullptr;
249
250 } // anonymous namespace
251
252 // This analysis propagates input device types (if any) throughout the
253 // graph.
DeviceTypePropagation(std::shared_ptr<Graph> & graph)254 bool DeviceTypePropagation(std::shared_ptr<Graph>& graph) {
255 auto tp = std::make_unique<DeviceTypePropagationPass>((graph));
256 bool changed = tp->run();
257 if (changed) {
258 GRAPH_DUMP("After TensorPropertyPropagation pass:", graph);
259 }
260 return changed;
261 }
262
263 } // namespace torch::jit
264