xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/device_type_analysis.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/interned_strings.h>
2 #include <ATen/core/jit_type.h>
3 #include <c10/core/Device.h>
4 #include <c10/util/ArrayRef.h>
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/jit_log.h>
7 #include <torch/csrc/jit/passes/device_type_analysis.h>
8 #include <torch/csrc/jit/passes/shape_analysis.h>
9 #include <memory>
10 #include <optional>
11 #include <utility>
12 
13 namespace torch::jit {
14 
15 namespace {
16 
17 using Tensor = at::Tensor;
18 using Device = at::Device;
19 
20 using PropRule = std::function<bool(Node*)>;
21 /*
22 A Propagation Rule takes the Node, and
23 applies the relevant properties to the Tensor outputs
24 of the Node (based on the rule itself)
25 
26 Returns: Bool indicating if anything was changed
27 */
28 
setDeviceType(Value * value,std::optional<Device> device)29 bool setDeviceType(Value* value, std::optional<Device> device) {
30   auto tensor_type = value->type()->expect<TensorType>();
31   bool changed = tensor_type->device() != device;
32   if (changed) {
33     value->setType(tensor_type->withDevice(device));
34   }
35   return changed;
36 }
37 
setReturnsToDevice(Node * n,std::optional<Device> device)38 bool setReturnsToDevice(Node* n, std::optional<Device> device) {
39   bool changed = false;
40   for (Value* out : n->outputs()) {
41     auto tensor_type = out->type()->cast<TensorType>();
42     if (!tensor_type) {
43       continue;
44     }
45     changed |= setDeviceType(out, device);
46   }
47   return changed;
48 }
49 
setReturnstoDeviceRule(DeviceType deviceType)50 PropRule setReturnstoDeviceRule(DeviceType deviceType) {
51   Device device = Device(deviceType);
52   return [=](Node* n) { return setReturnsToDevice(n, device); };
53 }
54 
returnFirstArgDeviceRule(Node * n)55 bool returnFirstArgDeviceRule(Node* n) {
56   // Custom Rule for when multiple args can have mismatched device types
57   auto tensor_type = n->inputs()[0]->type()->cast<TensorType>();
58   TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type");
59   return setReturnsToDevice(n, tensor_type->device());
60 }
61 
returnSecondArgDeviceRule(Node * n)62 bool returnSecondArgDeviceRule(Node* n) {
63   // Custom Rule for when multiple args can have mismatched device types
64   auto tensor_type = n->inputs()[1]->type()->cast<TensorType>();
65   TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type");
66   return setReturnsToDevice(n, tensor_type->device());
67 }
68 
isZerodimCPUTensor(const std::shared_ptr<TensorType> & tensor_type)69 bool isZerodimCPUTensor(const std::shared_ptr<TensorType>& tensor_type) {
70   // CPU devices on zerodim tensors are the only device that can be
71   // overwritten by another device. Therefore, to be conservative
72   // assume that it is not a zerodim cpu tensor if something is not known.
73   bool is_zerodim = tensor_type->symbolic_sizes().rank().value_or(-1) == 0;
74   bool is_cpu = tensor_type->device() && tensor_type->device()->is_cpu();
75   return is_zerodim && is_cpu;
76 }
77 
propWithNoDevice(Node * n)78 bool propWithNoDevice(Node* n) {
79   // Propagate if we can verify that all input devices match,
80   // except CPU zerodim, which any other type can overwrite
81   size_t input_num = 0;
82 
83   for (; input_num < n->inputs().size(); input_num++) {
84     if (n->inputs()[input_num]->type()->cast<TensorType>()) {
85       break;
86     }
87   }
88   if (input_num == n->inputs().size()) {
89     // No tensor found
90     return setReturnsToDevice(n, std::nullopt);
91   }
92 
93   auto tensor_type = n->inputs()[input_num]->type()->expect<TensorType>();
94   bool only_seen_cpu_zerodim = isZerodimCPUTensor(tensor_type);
95   std::optional<Device> device = tensor_type->device();
96 
97   // Now see if all inputs have a consistent device type
98   for (input_num++; input_num < n->inputs().size(); input_num++) {
99     auto tensor_type = n->inputs()[input_num]->type()->cast<TensorType>();
100     if (!tensor_type || isZerodimCPUTensor(tensor_type)) {
101       continue;
102     }
103 
104     if (device != tensor_type->device()) {
105       if (only_seen_cpu_zerodim) {
106         device = tensor_type->device();
107         only_seen_cpu_zerodim = false;
108       } else {
109         // Bail on the type not match case
110         return setReturnsToDevice(n, std::nullopt);
111       }
112     }
113   }
114   return setReturnsToDevice(n, device);
115 }
116 
defaultDeviceProp(Node * n)117 bool defaultDeviceProp(Node* n) {
118   // Detecting if the op has a device object argument
119   // as there is implicit string conversion to device
120   auto schema = n->maybeSchema();
121   if (!schema) {
122     return false;
123   }
124   auto arguments = schema->arguments();
125   for (size_t i = 0; i < arguments.size(); i++) {
126     Argument& argument = arguments[i];
127     if (DeviceObjType::get()->isSubtypeOf(argument.type())) {
128       // Optional args are filled in by torchscript with default val
129       auto input_val = toIValue(n->inputs().at(i));
130       if (!input_val.has_value()) {
131         // Can't propagate if there is a dynamic device type
132         return false;
133       }
134       if (input_val->isNone()) {
135         continue;
136       }
137       if (!input_val->isDevice()) {
138         // Bail on union types
139         return false;
140       }
141       TORCH_INTERNAL_ASSERT(input_val->isDevice())
142       Device device = input_val->toDevice();
143       return setReturnsToDevice(n, device);
144     }
145   }
146   return propWithNoDevice(n);
147 }
148 
149 struct DeviceTypePropagationPass : public PropertyPropBase {
DeviceTypePropagationPasstorch::jit::__anona9b556810111::DeviceTypePropagationPass150   explicit DeviceTypePropagationPass(std::shared_ptr<Graph> graph)
151       : PropertyPropBase(std::move(graph)) {
152     buildRuleRegistry();
153   }
154 
155   // returns true if at least one node has its scalar type set on a tensor node
runtorch::jit::__anona9b556810111::DeviceTypePropagationPass156   bool run() {
157     propagateBlock(graph_->block(), false);
158     return changed_;
159   }
160 
161  private:
propagateNodetorch::jit::__anona9b556810111::DeviceTypePropagationPass162   void propagateNode(Node* n, bool _ = true) override {
163     GRAPH_DEBUG("processNode");
164     switch (n->kind()) {
165       case prim::If:
166         return processIf(n);
167       case prim::Loop:
168         return processLoop(n);
169       case prim::CallMethod:
170       case prim::CallFunction:
171         return; // Not handled for now
172       default:
173         break;
174     }
175 
176     bool has_tensor_output =
177         std::any_of(n->outputs().begin(), n->outputs().end(), [](Value* v) {
178           return (bool)v->type()->cast<TensorType>();
179         });
180 
181     if (!has_tensor_output) {
182       // if output contains no tensor, nothing to propagate
183       return;
184     }
185 
186     switch (n->kind()) {
187       case prim::Constant:
188         // This is already been propagated by something else
189       case prim::ListConstruct:
190       case prim::ListUnpack:
191         return; // Not handled for now
192       default:
193         if (n->kind().is_aten()) {
194           return processAtenOps(n);
195         } else {
196           return; // Not handled for now
197         }
198     }
199   }
200 
processAtenOpstorch::jit::__anona9b556810111::DeviceTypePropagationPass201   void processAtenOps(Node* n) {
202     GRAPH_DEBUG("processAtenOps");
203     GRAPH_DEBUG("case = ", n->kind(), " ", *n);
204     // Custom Rule Matching
205     auto op = n->maybeOperator();
206     if (!op) {
207       return;
208     }
209     auto prop_fn = device_prop_registry_->find(*op);
210     if (prop_fn) {
211       PropRule rule = *prop_fn;
212       changed_ |= rule(n);
213       return;
214     }
215     changed_ |= defaultDeviceProp(n);
216   }
217 
buildRuleRegistrytorch::jit::__anona9b556810111::DeviceTypePropagationPass218   void buildRuleRegistry() {
219     // building a registry for all of the custom Device Type rules
220     if (device_prop_registry_)
221       return;
222 
223     static OperatorMap<PropRule> temp_registry{
224         {"aten::cpu(Tensor self) -> Tensor",
225          setReturnstoDeviceRule(DeviceType::CPU)},
226         {"aten::cuda(Tensor self) -> Tensor",
227          setReturnstoDeviceRule(DeviceType::CUDA)},
228         {"aten::to_mkldnn(Tensor self, ScalarType? dtype) -> Tensor",
229          setReturnstoDeviceRule(DeviceType::MKLDNN)},
230         {"aten::reshape_as(Tensor self, Tensor other) -> Tensor",
231          returnFirstArgDeviceRule},
232         {"aten::view_as(Tensor self, Tensor other) -> Tensor",
233          returnFirstArgDeviceRule},
234         {"aten::expand_as(Tensor self, Tensor other) -> Tensor",
235          returnFirstArgDeviceRule},
236         {"aten::type_as(Tensor self, Tensor other) -> Tensor",
237          returnSecondArgDeviceRule},
238     };
239     device_prop_registry_ =
240         std::make_unique<OperatorMap<PropRule>>(std::move(temp_registry));
241   }
242 
243   static std::unique_ptr<OperatorMap<PropRule>> device_prop_registry_;
244   bool changed_ = false;
245 };
246 
// Out-of-class definition of the shared rule registry; populated
// lazily by the first pass instance (see buildRuleRegistry).
std::unique_ptr<OperatorMap<PropRule>>
    DeviceTypePropagationPass::device_prop_registry_ = nullptr;
249 
250 } // anonymous namespace
251 
252 // This analysis propagates input device types (if any) throughout the
253 // graph.
DeviceTypePropagation(std::shared_ptr<Graph> & graph)254 bool DeviceTypePropagation(std::shared_ptr<Graph>& graph) {
255   auto tp = std::make_unique<DeviceTypePropagationPass>((graph));
256   bool changed = tp->run();
257   if (changed) {
258     GRAPH_DUMP("After TensorPropertyPropagation pass:", graph);
259   }
260   return changed;
261 }
262 
263 } // namespace torch::jit
264