xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/utils/check_alias_annotation.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
2 
3 #include <torch/csrc/jit/passes/constant_propagation.h>
4 #include <torch/csrc/jit/passes/normalize_ops.h>
5 #include <torch/csrc/jit/runtime/operator.h>
6 
7 #include <c10/util/irange.h>
8 
9 namespace torch {
10 namespace jit {
11 namespace {
12 
// Produce a deep copy of `self` so later mutation of the original can be
// detected. Reference types (tensors, lists) are duplicated; primitives are
// returned as-is.
IValue deepCopy(const IValue& self) {
  // Non-reference types are safe to copy by value directly.
  if (!self.isPtrType()) {
    return self;
  }

  // A tensor must be cloned; copy assignment would alias the same storage.
  if (self.isTensor()) {
    return IValue(self.toTensor().clone(at::MemoryFormat::Preserve));
  }

  // Generic lists: rebuild the list, deep-copying every element.
  if (self.isList()) {
    const auto original = self.toList();
    auto copied = c10::impl::GenericList(original.elementType());
    copied.reserve(original.size());
    for (const IValue& element : original) {
      copied.push_back(deepCopy(element));
    }
    return copied;
  }

  // Specialized lists hold value types, so copying the container suffices.
  if (self.isIntList()) {
    return IValue(self.toIntList().copy());
  }
  if (self.isDoubleList()) {
    return IValue(self.toDoubleList().copy());
  }
  if (self.isComplexDoubleList()) {
    return IValue(self.toComplexDoubleList().copy());
  }
  if (self.isBoolList()) {
    return IValue(self.toBoolList().copy());
  }
  if (self.isString()) {
    return IValue(self.toStringRef());
  }

  // If in the future more reference types are used in aten ops, they need
  // to be handled as cases above.
  AT_ASSERT(false);
}
52 
// Deep-copy every IValue on the stack, preserving order.
Stack deepCopy(const Stack& stack) {
  Stack copied;
  copied.reserve(stack.size());
  for (const auto& ivalue : stack) {
    copied.push_back(deepCopy(ivalue));
  }
  return copied;
}
61 
deepEquals(const IValue & lhs,const IValue & rhs)62 bool deepEquals(const IValue& lhs, const IValue& rhs) {
63   if (lhs.isTensor() && rhs.isTensor()) {
64     return lhs.toTensor().equal(rhs.toTensor());
65   }
66 
67   if (lhs.isTensorList() && rhs.isTensorList()) {
68     const auto a = lhs.toTensorList();
69     const auto b = rhs.toTensorList();
70     if (a.size() != b.size()) {
71       return false;
72     }
73     for (auto i = decltype(a.size()){0}; i < a.size(); ++i) {
74       if (!a[i].equal(b[i])) {
75         return false;
76       }
77     }
78     return true;
79   }
80 
81   return lhs == rhs;
82 }
83 
// Pairs a runtime IValue with the alias annotation from the op schema that
// describes it, so inputs/outputs can be checked against their annotations.
struct AliasAndIValue {
  AliasAndIValue(const at::AliasInfo* aliasInfo, IValue iValue)
      : aliasInfo(aliasInfo), iValue(std::move(iValue)) {}

  // Schema alias annotation; null when the argument carries no annotation.
  const at::AliasInfo* aliasInfo;
  const IValue iValue;
};
91 
92 // No inputs should alias each other
checkInputPreconditions(const Stack & inputs)93 void checkInputPreconditions(const Stack& inputs) {
94   for (const auto i : c10::irange(inputs.size())) {
95     for (const auto j : c10::irange(inputs.size())) {
96       if (i == j) {
97         continue;
98       }
99       const auto& lhs = inputs.at(i);
100       const auto& rhs = inputs.at(j);
101       AT_ASSERT(!lhs.isAliasOf(rhs));
102     }
103   }
104 }
105 
106 // If two ivalues alias, they must share an alias set
checkAliases(const std::vector<AliasAndIValue> & inputs,const std::vector<AliasAndIValue> & outputs)107 void checkAliases(
108     const std::vector<AliasAndIValue>& inputs,
109     const std::vector<AliasAndIValue>& outputs) {
110   for (const auto& output : outputs) {
111     // if this output aliases any input, make sure that they share an alias set
112     for (const auto& input : inputs) {
113       if (output.iValue.isAliasOf(input.iValue)) {
114         const auto* inputSet = input.aliasInfo;
115         const auto* outputSet = output.aliasInfo;
116         AT_ASSERT(inputSet && outputSet);
117         bool found = false;
118         for (const auto& set : inputSet->beforeSets()) {
119           if (outputSet->beforeSets().count(set)) {
120             found = true;
121             break;
122           }
123         }
124         AT_ASSERT(found);
125       }
126     }
127   }
128 }
129 
130 // If we didn't specify that we write to an input value, it must have not
131 // changed
checkWrites(const std::vector<AliasAndIValue> & inputs,const std::vector<IValue> & deepCopiedInputs)132 void checkWrites(
133     const std::vector<AliasAndIValue>& inputs,
134     const std::vector<IValue>& deepCopiedInputs) {
135   AT_ASSERT(inputs.size() == deepCopiedInputs.size());
136   for (const auto i : c10::irange(inputs.size())) {
137     const auto& input = inputs[i];
138     const auto& deepCopiedInput = deepCopiedInputs[i];
139     if (!input.aliasInfo || !input.aliasInfo->isWrite()) {
140       AT_ASSERT(deepEquals(input.iValue, deepCopiedInput));
141     }
142   }
143 }
144 
findNodeForOp(const Graph & g,const std::string & unqualifiedOpName)145 const Node* findNodeForOp(
146     const Graph& g,
147     const std::string& unqualifiedOpName) {
148   const auto opName = Symbol::fromQualString("aten::" + unqualifiedOpName);
149   for (const auto* node : g.nodes()) {
150     if (node->kind() == opName) {
151       return node;
152     }
153   }
154 
155   // Check for alias-ed operator names
156   const auto aliasOp = torch::jit::getOperatorAliasMap().find(opName);
157   if (aliasOp != torch::jit::getOperatorAliasMap().end()) {
158     for (const auto* node : g.nodes()) {
159       if (node->kind() == aliasOp->second) {
160         return node;
161       }
162     }
163   }
164 
165   // Ideally, there will be only one ATen operator that has tensor outputs in
166   // the graph. Let's use that as the last resolve to make checkAliasAnnotation
167   // more robust.
168   for (const auto* node : g.nodes()) {
169     if (!node->maybeOperator()) {
170       continue;
171     }
172     if (!node->getOperator().isC10Op()) {
173       continue;
174     }
175 
176     for (const auto* output : node->outputs()) {
177       if (output->type()->kind() == TypeKind::TensorType) {
178         return node;
179       }
180     }
181   }
182 
183   AT_ASSERT(false);
184 }
185 
186 // Handle a few special cases where we need to propagate constants
187 // manually
188 // TODO(suo): we should be able to move this stuff to constant prop
toIValueProp(const Value * v)189 std::optional<IValue> toIValueProp(const Value* v) {
190   if (v->node()->kind() == prim::ListConstruct) {
191     std::vector<IValue> genericList;
192     for (auto input : v->node()->inputs()) {
193       if (auto elem = toIValue(input)) {
194         genericList.push_back(*elem);
195       } else {
196         // One of the list elements isn't constant.
197         return std::nullopt;
198       }
199     }
200 
201     // Specialize the list based on ListConstruct's return type
202     auto listType = v->node()->output()->type();
203     auto containedType = listType->containedTypes().at(0);
204     if (containedType == IntType::get()) {
205       return IValue(
206           fmap(genericList, [](const IValue& v) { return v.toInt(); }));
207     } else if (containedType == FloatType::get()) {
208       return IValue(
209           fmap(genericList, [](const IValue& v) { return v.toDouble(); }));
210     } else if (containedType->isSubtypeOf(*TensorType::get())) {
211       return IValue(
212           fmap(genericList, [](const IValue& v) { return v.toTensor(); }));
213     } else {
214       return std::nullopt;
215     }
216   }
217 
218   if (v->node()->kind() == aten::Float) {
219     if (auto maybe_stack = runNodeIfInputsAreConstant(v->node())) {
220       return maybe_stack->at(0);
221     }
222   }
223   return std::nullopt;
224 }
225 
226 // batch_norm and instance_norm have incorrect annotations, because
227 // (a!)? annotations aren't supported, so these checks would fail.
228 // Their behavior also varies depending on the `training` and
229 // `use_input_stats` arguments.
230 // There are custom implementations in alias_analysis.cpp for these ops.
shouldIgnoreNode(const Node * n)231 bool shouldIgnoreNode(const Node* n) {
232   switch (n->kind()) {
233     case aten::batch_norm:
234     case aten::instance_norm:
235       return true;
236     default:
237       return false;
238   }
239 }
240 } // namespace
241 
// Runs the op named `unqualifiedOpName` found in `graph` on `pythonInputs`
// and asserts that its schema's alias annotations are correct: any alias
// between an input and an output must be annotated, and any input not
// annotated as written-to must be unchanged after execution. Failures
// surface as AT_ASSERT errors.
void checkAliasAnnotation(
    const std::shared_ptr<Graph>& graph,
    std::vector<IValue> pythonInputs,
    const std::string& unqualifiedOpName) {
  // Find the node that corresponds to our op name
  const auto node = findNodeForOp(*graph, unqualifiedOpName);
  if (shouldIgnoreNode(node)) {
    // Known ops with intentionally-wrong annotations; see shouldIgnoreNode.
    return;
  }

  // Build the stack to use as input to the op
  Stack stack;
  for (const auto input : node->inputs()) {
    if (input->node() == graph->param_node()) {
      // This value was passed as an input in python
      push(stack, pythonInputs.at(input->offset()));
    } else {
      // This is a generated constant, which we need to evaluate
      auto inputValue = toIValue(input);
      if (!inputValue) {
        // Fall back to the manual constant-propagation special cases.
        inputValue = toIValueProp(input);
      }

      if (inputValue) {
        push(stack, *inputValue);
      } else {
        // Only optional arguments may be non-constant here; push None.
        AT_ASSERT(input->type()->kind() == TypeKind::OptionalType);
        push(stack, IValue());
      }
    }
  }

  // Precondition: no inputs should alias each other. So if we find an alias,
  // it was created by the op.
  checkInputPreconditions(stack);

  const auto& schema = node->schema();

  // Pair each input with its alias annotation from the schema.
  std::vector<AliasAndIValue> inputsToCheck;
  for (const auto i : c10::irange(schema.arguments().size())) {
    inputsToCheck.emplace_back(
        schema.arguments().at(i).alias_info(), stack.at(i));
  }

  // Save a copy of the inputs so we can check whether the original inputs
  // were written to. Must happen before the op runs.
  const auto inputsDeepCopy = deepCopy(stack);

  // Run the op
  node->getOperation()(stack);

  // After execution the stack holds the op's outputs.
  const auto outputs = std::move(stack);

  // Pair each output with its alias annotation from the schema.
  std::vector<AliasAndIValue> outputsToCheck;
  for (const auto i : c10::irange(schema.returns().size())) {
    outputsToCheck.emplace_back(
        schema.returns().at(i).alias_info(), outputs.at(i));
  }

  // Check that if any alias was created, we annotated it properly.
  checkAliases(inputsToCheck, outputsToCheck);

  // Check that nothing was accidentally written to.
  checkWrites(inputsToCheck, inputsDeepCopy);
}
307 
308 } // namespace jit
309 } // namespace torch
310